Use openai compatible as vendor
This commit is contained in:
@@ -16,16 +16,22 @@ from ..schemas import (
|
||||
|
||||
router = APIRouter(prefix="/asr", tags=["ASR Models"])
|
||||
|
||||
SILICONFLOW_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
|
||||
OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
|
||||
|
||||
|
||||
def _is_siliconflow_vendor(vendor: str) -> bool:
|
||||
return (vendor or "").strip().lower() in {"siliconflow", "硅基流动"}
|
||||
def _is_openai_compatible_vendor(vendor: str) -> bool:
|
||||
normalized = (vendor or "").strip().lower()
|
||||
return normalized in {
|
||||
"openai compatible",
|
||||
"openai-compatible",
|
||||
"siliconflow", # backward compatibility
|
||||
"硅基流动", # backward compatibility
|
||||
}
|
||||
|
||||
|
||||
def _default_asr_model(vendor: str) -> str:
|
||||
if _is_siliconflow_vendor(vendor):
|
||||
return SILICONFLOW_DEFAULT_ASR_MODEL
|
||||
if _is_openai_compatible_vendor(vendor):
|
||||
return OPENAI_COMPATIBLE_DEFAULT_ASR_MODEL
|
||||
return "whisper-1"
|
||||
|
||||
|
||||
@@ -129,7 +135,7 @@ def test_asr_model(
|
||||
# 连接性测试优先,避免依赖真实音频输入
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
if model.vendor.lower() in ["siliconflow", "paraformer"]:
|
||||
if _is_openai_compatible_vendor(model.vendor) or model.vendor.lower() == "paraformer":
|
||||
response = client.get(f"{model.base_url}/asr", headers=headers)
|
||||
elif model.vendor.lower() == "openai":
|
||||
response = client.get(f"{model.base_url}/audio/models", headers=headers)
|
||||
@@ -258,7 +264,7 @@ async def preview_asr_model(
|
||||
raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
|
||||
|
||||
effective_api_key = (api_key or "").strip() or (model.api_key or "").strip()
|
||||
if not effective_api_key and _is_siliconflow_vendor(model.vendor):
|
||||
if not effective_api_key and _is_openai_compatible_vendor(model.vendor):
|
||||
effective_api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
|
||||
if not effective_api_key:
|
||||
raise HTTPException(status_code=400, detail=f"API key is required for ASR model: {model.name}")
|
||||
|
||||
@@ -13,8 +13,13 @@ from ..schemas import (
|
||||
router = APIRouter(prefix="/assistants", tags=["Assistants"])
|
||||
|
||||
|
||||
def _is_siliconflow_vendor(vendor: Optional[str]) -> bool:
|
||||
return (vendor or "").strip().lower() in {"siliconflow", "硅基流动"}
|
||||
def _is_openai_compatible_vendor(vendor: Optional[str]) -> bool:
|
||||
return (vendor or "").strip().lower() in {
|
||||
"siliconflow",
|
||||
"硅基流动",
|
||||
"openai compatible",
|
||||
"openai-compatible",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
@@ -47,11 +52,11 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
if assistant.asr_model_id:
|
||||
asr = db.query(ASRModel).filter(ASRModel.id == assistant.asr_model_id).first()
|
||||
if asr:
|
||||
asr_provider = "siliconflow" if _is_siliconflow_vendor(asr.vendor) else "buffered"
|
||||
asr_provider = "openai_compatible" if _is_openai_compatible_vendor(asr.vendor) else "buffered"
|
||||
metadata["services"]["asr"] = {
|
||||
"provider": asr_provider,
|
||||
"model": asr.model_name or asr.name,
|
||||
"apiKey": asr.api_key if asr_provider == "siliconflow" else None,
|
||||
"apiKey": asr.api_key if asr_provider == "openai_compatible" else None,
|
||||
}
|
||||
else:
|
||||
warnings.append(f"ASR model not found: {assistant.asr_model_id}")
|
||||
@@ -61,12 +66,12 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
elif assistant.voice:
|
||||
voice = db.query(Voice).filter(Voice.id == assistant.voice).first()
|
||||
if voice:
|
||||
tts_provider = "siliconflow" if _is_siliconflow_vendor(voice.vendor) else "edge"
|
||||
tts_provider = "openai_compatible" if _is_openai_compatible_vendor(voice.vendor) else "edge"
|
||||
metadata["services"]["tts"] = {
|
||||
"enabled": True,
|
||||
"provider": tts_provider,
|
||||
"model": voice.model,
|
||||
"apiKey": voice.api_key if tts_provider == "siliconflow" else None,
|
||||
"apiKey": voice.api_key if tts_provider == "openai_compatible" else None,
|
||||
"voice": voice.voice_key or voice.id,
|
||||
"speed": assistant.speed or voice.speed,
|
||||
}
|
||||
|
||||
@@ -467,7 +467,13 @@ def _test_asr_model(db: Session, model_id: str, result: AutotestResult):
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
if model.vendor.lower() in ["siliconflow", "paraformer"]:
|
||||
normalized_vendor = (model.vendor or "").strip().lower()
|
||||
if normalized_vendor in [
|
||||
"openai compatible",
|
||||
"openai-compatible",
|
||||
"siliconflow", # backward compatibility
|
||||
"paraformer",
|
||||
]:
|
||||
response = client.get(
|
||||
f"{model.base_url}/asr",
|
||||
headers=headers
|
||||
|
||||
@@ -13,20 +13,26 @@ from ..schemas import VoiceCreate, VoiceOut, VoicePreviewRequest, VoicePreviewRe
|
||||
|
||||
router = APIRouter(prefix="/voices", tags=["Voices"])
|
||||
|
||||
SILICONFLOW_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
|
||||
OPENAI_COMPATIBLE_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
|
||||
|
||||
|
||||
def _is_siliconflow_vendor(vendor: str) -> bool:
|
||||
return vendor.strip().lower() in {"siliconflow", "硅基流动"}
|
||||
def _is_openai_compatible_vendor(vendor: str) -> bool:
|
||||
normalized = (vendor or "").strip().lower()
|
||||
return normalized in {
|
||||
"openai compatible",
|
||||
"openai-compatible",
|
||||
"siliconflow", # backward compatibility
|
||||
"硅基流动", # backward compatibility
|
||||
}
|
||||
|
||||
|
||||
def _default_base_url(vendor: str) -> Optional[str]:
|
||||
if _is_siliconflow_vendor(vendor):
|
||||
if _is_openai_compatible_vendor(vendor):
|
||||
return "https://api.siliconflow.cn/v1"
|
||||
return None
|
||||
|
||||
|
||||
def _build_siliconflow_voice_key(voice: Voice, model: str) -> str:
|
||||
def _build_openai_compatible_voice_key(voice: Voice, model: str) -> str:
|
||||
if voice.voice_key:
|
||||
return voice.voice_key
|
||||
if ":" in voice.id:
|
||||
@@ -65,8 +71,8 @@ def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
|
||||
model = data.model
|
||||
voice_key = data.voice_key
|
||||
|
||||
if _is_siliconflow_vendor(vendor):
|
||||
model = model or SILICONFLOW_DEFAULT_MODEL
|
||||
if _is_openai_compatible_vendor(vendor):
|
||||
model = model or OPENAI_COMPATIBLE_DEFAULT_MODEL
|
||||
if not voice_key:
|
||||
raw_id = (data.id or data.name).strip()
|
||||
voice_key = raw_id if ":" in raw_id else f"{model}:{raw_id}"
|
||||
@@ -115,11 +121,11 @@ def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
|
||||
update_data["vendor"] = update_data["vendor"].strip()
|
||||
|
||||
vendor_for_defaults = update_data.get("vendor", voice.vendor)
|
||||
if _is_siliconflow_vendor(vendor_for_defaults):
|
||||
model = update_data.get("model") or voice.model or SILICONFLOW_DEFAULT_MODEL
|
||||
if _is_openai_compatible_vendor(vendor_for_defaults):
|
||||
model = update_data.get("model") or voice.model or OPENAI_COMPATIBLE_DEFAULT_MODEL
|
||||
voice_key = update_data.get("voice_key") or voice.voice_key
|
||||
update_data["model"] = model
|
||||
update_data["voice_key"] = voice_key or _build_siliconflow_voice_key(voice, model)
|
||||
update_data["voice_key"] = voice_key or _build_openai_compatible_voice_key(voice, model)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(voice, field, value)
|
||||
@@ -152,7 +158,7 @@ def preview_voice(id: str, data: VoicePreviewRequest, db: Session = Depends(get_
|
||||
raise HTTPException(status_code=400, detail="Preview text cannot be empty")
|
||||
|
||||
api_key = (data.api_key or "").strip() or (voice.api_key or "").strip()
|
||||
if not api_key and _is_siliconflow_vendor(voice.vendor):
|
||||
if not api_key and _is_openai_compatible_vendor(voice.vendor):
|
||||
api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail=f"API key is required for voice: {voice.name}")
|
||||
@@ -161,11 +167,11 @@ def preview_voice(id: str, data: VoicePreviewRequest, db: Session = Depends(get_
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail=f"Base URL is required for voice: {voice.name}")
|
||||
|
||||
model = voice.model or SILICONFLOW_DEFAULT_MODEL
|
||||
model = voice.model or OPENAI_COMPATIBLE_DEFAULT_MODEL
|
||||
payload = {
|
||||
"model": model,
|
||||
"input": text,
|
||||
"voice": voice.voice_key or _build_siliconflow_voice_key(voice, model),
|
||||
"voice": voice.voice_key or _build_openai_compatible_voice_key(voice, model),
|
||||
"response_format": "mp3",
|
||||
"speed": data.speed if data.speed is not None else voice.speed,
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ interface ASRModel {
|
||||
id: string; // 模型唯一标识 (8位UUID)
|
||||
user_id: number; // 所属用户ID
|
||||
name: string; // 模型显示名称
|
||||
vendor: string; // 供应商: "OpenAI" | "SiliconFlow" | "Paraformer" | 等
|
||||
vendor: string; // 供应商: "OpenAI Compatible" | "Paraformer" | 等
|
||||
language: string; // 识别语言: "zh" | "en" | "Multi-lingual"
|
||||
base_url: string; // API Base URL
|
||||
api_key: string; // API Key
|
||||
@@ -64,7 +64,7 @@ GET /api/v1/asr
|
||||
"id": "abc12345",
|
||||
"user_id": 1,
|
||||
"name": "Whisper 多语种识别",
|
||||
"vendor": "OpenAI",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"language": "Multi-lingual",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-***",
|
||||
@@ -78,7 +78,7 @@ GET /api/v1/asr
|
||||
"id": "def67890",
|
||||
"user_id": 1,
|
||||
"name": "SenseVoice 中文识别",
|
||||
"vendor": "SiliconFlow",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"language": "zh",
|
||||
"base_url": "https://api.siliconflow.cn/v1",
|
||||
"api_key": "sf-***",
|
||||
@@ -114,7 +114,7 @@ GET /api/v1/asr/{id}
|
||||
"id": "abc12345",
|
||||
"user_id": 1,
|
||||
"name": "Whisper 多语种识别",
|
||||
"vendor": "OpenAI",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"language": "Multi-lingual",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-***",
|
||||
@@ -140,7 +140,7 @@ POST /api/v1/asr
|
||||
```json
|
||||
{
|
||||
"name": "SenseVoice 中文识别",
|
||||
"vendor": "SiliconFlow",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"language": "zh",
|
||||
"base_url": "https://api.siliconflow.cn/v1",
|
||||
"api_key": "sk-your-api-key",
|
||||
@@ -157,7 +157,7 @@ POST /api/v1/asr
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
|------|------|------|------|
|
||||
| name | string | 是 | 模型显示名称 |
|
||||
| vendor | string | 是 | 供应商: "OpenAI" / "SiliconFlow" / "Paraformer" |
|
||||
| vendor | string | 是 | 供应商: "OpenAI Compatible" / "Paraformer" |
|
||||
| language | string | 是 | 语言: "zh" / "en" / "Multi-lingual" |
|
||||
| base_url | string | 是 | API Base URL |
|
||||
| api_key | string | 是 | API Key |
|
||||
@@ -347,7 +347,7 @@ class ASRTestResponse(BaseModel):
|
||||
|
||||
```json
|
||||
{
|
||||
"vendor": "OpenAI",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-xxx",
|
||||
"model_name": "whisper-1",
|
||||
@@ -357,11 +357,11 @@ class ASRTestResponse(BaseModel):
|
||||
}
|
||||
```
|
||||
|
||||
### SiliconFlow Paraformer
|
||||
### OpenAI Compatible Paraformer
|
||||
|
||||
```json
|
||||
{
|
||||
"vendor": "SiliconFlow",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"base_url": "https://api.siliconflow.cn/v1",
|
||||
"api_key": "sf-xxx",
|
||||
"model_name": "paraformer-v2",
|
||||
@@ -393,7 +393,7 @@ class ASRTestResponse(BaseModel):
|
||||
| test_filter_asr_models_by_language | 按语言过滤测试 |
|
||||
| test_filter_asr_models_by_enabled | 按启用状态过滤测试 |
|
||||
| test_create_asr_model_with_hotwords | 热词配置测试 |
|
||||
| test_test_asr_model_siliconflow | SiliconFlow 供应商测试 |
|
||||
| test_test_asr_model_siliconflow | OpenAI Compatible 供应商测试 |
|
||||
| test_test_asr_model_openai | OpenAI 供应商测试 |
|
||||
| test_different_asr_languages | 多语言测试 |
|
||||
| test_different_asr_vendors | 多供应商测试 |
|
||||
|
||||
@@ -20,7 +20,7 @@ interface LLMModel {
|
||||
id: string; // 模型唯一标识 (8位UUID)
|
||||
user_id: number; // 所属用户ID
|
||||
name: string; // 模型显示名称
|
||||
vendor: string; // 供应商: "OpenAI" | "SiliconFlow" | "Dify" | "FastGPT" | 等
|
||||
vendor: string; // 供应商: "OpenAI Compatible" | "Dify" | "FastGPT" | 等
|
||||
type: string; // 类型: "text" | "embedding" | "rerank"
|
||||
base_url: string; // API Base URL
|
||||
api_key: string; // API Key
|
||||
@@ -64,7 +64,7 @@ GET /api/v1/llm
|
||||
"id": "abc12345",
|
||||
"user_id": 1,
|
||||
"name": "GPT-4o",
|
||||
"vendor": "OpenAI",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"type": "text",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-***",
|
||||
@@ -79,7 +79,7 @@ GET /api/v1/llm
|
||||
"id": "def67890",
|
||||
"user_id": 1,
|
||||
"name": "Embedding-3-Small",
|
||||
"vendor": "OpenAI",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"type": "embedding",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-***",
|
||||
@@ -111,7 +111,7 @@ GET /api/v1/llm/{id}
|
||||
"id": "abc12345",
|
||||
"user_id": 1,
|
||||
"name": "GPT-4o",
|
||||
"vendor": "OpenAI",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"type": "text",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-***",
|
||||
@@ -137,7 +137,7 @@ POST /api/v1/llm
|
||||
```json
|
||||
{
|
||||
"name": "GPT-4o",
|
||||
"vendor": "OpenAI",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"type": "text",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-your-api-key",
|
||||
@@ -314,11 +314,11 @@ class LLMModelTestResponse(BaseModel):
|
||||
|
||||
## 供应商配置示例
|
||||
|
||||
### OpenAI
|
||||
### OpenAI Compatible (OpenAI Endpoint)
|
||||
|
||||
```json
|
||||
{
|
||||
"vendor": "OpenAI",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-xxx",
|
||||
"model_name": "gpt-4o",
|
||||
@@ -327,11 +327,11 @@ class LLMModelTestResponse(BaseModel):
|
||||
}
|
||||
```
|
||||
|
||||
### SiliconFlow
|
||||
### OpenAI Compatible
|
||||
|
||||
```json
|
||||
{
|
||||
"vendor": "SiliconFlow",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"base_url": "https://api.siliconflow.com/v1",
|
||||
"api_key": "sf-xxx",
|
||||
"model_name": "deepseek-v3",
|
||||
@@ -356,7 +356,7 @@ class LLMModelTestResponse(BaseModel):
|
||||
|
||||
```json
|
||||
{
|
||||
"vendor": "OpenAI",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-xxx",
|
||||
"model_name": "text-embedding-3-small",
|
||||
|
||||
@@ -20,7 +20,7 @@ interface LLMModel {
|
||||
id: string; // 模型唯一标识
|
||||
user_id: number; // 所属用户ID
|
||||
name: string; // 模型显示名称
|
||||
vendor: string; // 供应商: "OpenAI Compatible" | "SiliconFlow" | "Dify" | "FastGPT"
|
||||
vendor: string; // 供应商: "OpenAI Compatible" | "Dify" | "FastGPT"
|
||||
type: string; // 类型: "text" | "embedding" | "rerank"
|
||||
base_url: string; // API Base URL
|
||||
api_key: string; // API Key
|
||||
@@ -57,7 +57,7 @@ interface TTSModel {
|
||||
id: string;
|
||||
user_id: number;
|
||||
name: string;
|
||||
vendor: string; // "Ali" | "Volcano" | "Minimax" | "硅基流动"
|
||||
vendor: string; // "OpenAI Compatible" | "Ali" | "Volcano" | "Minimax"
|
||||
language: string; // "zh" | "en"
|
||||
voice_list?: string[]; // 支持的声音列表
|
||||
enabled: boolean;
|
||||
@@ -316,7 +316,6 @@ class LLMModelType(str, Enum):
|
||||
|
||||
class LLMModelVendor(str, Enum):
|
||||
OPENAI_COMPATIBLE = "OpenAI Compatible"
|
||||
SILICONFLOW = "SiliconFlow"
|
||||
DIFY = "Dify"
|
||||
FASTGPT = "FastGPT"
|
||||
|
||||
@@ -389,11 +388,11 @@ class ASRModelOut(ASRModelBase):
|
||||
}
|
||||
```
|
||||
|
||||
### SiliconFlow
|
||||
### OpenAI Compatible
|
||||
|
||||
```json
|
||||
{
|
||||
"vendor": "SiliconFlow",
|
||||
"vendor": "OpenAI Compatible",
|
||||
"base_url": "https://api.siliconflow.com/v1",
|
||||
"api_key": "sf-xxx",
|
||||
"model_name": "deepseek-v3"
|
||||
|
||||
@@ -135,21 +135,21 @@ def rebuild_vector_store(reset_doc_status: bool = True):
|
||||
def init_default_data():
|
||||
with db_session() as db:
|
||||
# 检查是否已有数据
|
||||
# SiliconFlow CosyVoice 2.0 预设声音 (8个)
|
||||
# OpenAI Compatible (SiliconFlow API) CosyVoice 2.0 预设声音 (8个)
|
||||
# 参考: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
|
||||
voices = [
|
||||
# 男声 (Male Voices)
|
||||
Voice(id="alex", name="Alex", vendor="SiliconFlow", gender="Male", language="en",
|
||||
Voice(id="alex", name="Alex", vendor="OpenAI Compatible", gender="Male", language="en",
|
||||
description="Steady male voice.", is_system=True),
|
||||
Voice(id="david", name="David", vendor="SiliconFlow", gender="Male", language="en",
|
||||
Voice(id="david", name="David", vendor="OpenAI Compatible", gender="Male", language="en",
|
||||
description="Cheerful male voice.", is_system=True),
|
||||
# 女声 (Female Voices)
|
||||
Voice(id="bella", name="Bella", vendor="SiliconFlow", gender="Female", language="en",
|
||||
Voice(id="bella", name="Bella", vendor="OpenAI Compatible", gender="Female", language="en",
|
||||
description="Passionate female voice.", is_system=True),
|
||||
Voice(id="claire", name="Claire", vendor="SiliconFlow", gender="Female", language="en",
|
||||
Voice(id="claire", name="Claire", vendor="OpenAI Compatible", gender="Female", language="en",
|
||||
description="Gentle female voice.", is_system=True),
|
||||
]
|
||||
seed_if_empty(db, Voice, voices, "✅ 默认声音数据已初始化 (SiliconFlow CosyVoice 2.0)")
|
||||
seed_if_empty(db, Voice, voices, "✅ 默认声音数据已初始化 (OpenAI Compatible CosyVoice 2.0)")
|
||||
|
||||
|
||||
def init_default_tools(recreate: bool = False):
|
||||
@@ -181,7 +181,7 @@ def init_default_assistants():
|
||||
voice="anna",
|
||||
speed=1.0,
|
||||
hotwords=[],
|
||||
tools=["calculator", "current_time"],
|
||||
tools=["current_time"],
|
||||
interruption_sensitivity=500,
|
||||
config_mode="platform",
|
||||
llm_model_id="deepseek-chat",
|
||||
@@ -215,7 +215,7 @@ def init_default_assistants():
|
||||
voice="alex",
|
||||
speed=1.0,
|
||||
hotwords=["grammar", "vocabulary", "practice"],
|
||||
tools=["calculator"],
|
||||
tools=["current_time"],
|
||||
interruption_sensitivity=400,
|
||||
config_mode="platform",
|
||||
),
|
||||
@@ -294,7 +294,7 @@ def init_default_llm_models():
|
||||
id="deepseek-chat",
|
||||
user_id=1,
|
||||
name="DeepSeek Chat",
|
||||
vendor="SiliconFlow",
|
||||
vendor="OpenAI Compatible",
|
||||
type="text",
|
||||
base_url="https://api.deepseek.com",
|
||||
api_key="YOUR_API_KEY", # 用户需替换
|
||||
@@ -320,7 +320,7 @@ def init_default_llm_models():
|
||||
id="text-embedding-3-small",
|
||||
user_id=1,
|
||||
name="Embedding 3 Small",
|
||||
vendor="OpenAI",
|
||||
vendor="OpenAI Compatible",
|
||||
type="embedding",
|
||||
base_url="https://api.openai.com/v1",
|
||||
api_key="YOUR_API_KEY",
|
||||
@@ -339,7 +339,7 @@ def init_default_asr_models():
|
||||
id="FunAudioLLM/SenseVoiceSmall",
|
||||
user_id=1,
|
||||
name="FunAudioLLM/SenseVoiceSmall",
|
||||
vendor="SiliconFlow",
|
||||
vendor="OpenAI Compatible",
|
||||
language="Multi-lingual",
|
||||
base_url="https://api.siliconflow.cn/v1",
|
||||
api_key="YOUR_API_KEY",
|
||||
@@ -353,7 +353,7 @@ def init_default_asr_models():
|
||||
id="TeleAI/TeleSpeechASR",
|
||||
user_id=1,
|
||||
name="TeleAI/TeleSpeechASR",
|
||||
vendor="SiliconFlow",
|
||||
vendor="OpenAI Compatible",
|
||||
language="Multi-lingual",
|
||||
base_url="https://api.siliconflow.cn/v1",
|
||||
api_key="YOUR_API_KEY",
|
||||
|
||||
@@ -41,19 +41,19 @@ LLM_MODEL=gpt-4o-mini
|
||||
LLM_TEMPERATURE=0.7
|
||||
|
||||
# TTS
|
||||
# edge: no SiliconFlow key needed
|
||||
# siliconflow: requires SILICONFLOW_API_KEY
|
||||
TTS_PROVIDER=siliconflow
|
||||
# edge: no API key needed
|
||||
# openai_compatible: compatible with SiliconFlow-style endpoints
|
||||
TTS_PROVIDER=openai_compatible
|
||||
TTS_VOICE=anna
|
||||
TTS_SPEED=1.0
|
||||
|
||||
# SiliconFlow (used by TTS and/or ASR when provider=siliconflow)
|
||||
# SiliconFlow (used by TTS and/or ASR when provider=openai_compatible)
|
||||
SILICONFLOW_API_KEY=your_siliconflow_api_key_here
|
||||
SILICONFLOW_TTS_MODEL=FunAudioLLM/CosyVoice2-0.5B
|
||||
SILICONFLOW_ASR_MODEL=FunAudioLLM/SenseVoiceSmall
|
||||
|
||||
# ASR
|
||||
ASR_PROVIDER=siliconflow
|
||||
ASR_PROVIDER=openai_compatible
|
||||
# Interim cadence and minimum audio before interim decode.
|
||||
ASR_INTERIM_INTERVAL_MS=500
|
||||
ASR_MIN_AUDIO_MS=300
|
||||
|
||||
@@ -44,7 +44,10 @@ class Settings(BaseSettings):
|
||||
llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation")
|
||||
|
||||
# TTS Configuration
|
||||
tts_provider: str = Field(default="siliconflow", description="TTS provider (edge, siliconflow)")
|
||||
tts_provider: str = Field(
|
||||
default="openai_compatible",
|
||||
description="TTS provider (edge, openai_compatible; siliconflow alias supported)"
|
||||
)
|
||||
tts_voice: str = Field(default="anna", description="TTS voice name")
|
||||
tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier")
|
||||
|
||||
@@ -53,7 +56,10 @@ class Settings(BaseSettings):
|
||||
siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", description="SiliconFlow TTS model")
|
||||
|
||||
# ASR Configuration
|
||||
asr_provider: str = Field(default="siliconflow", description="ASR provider (siliconflow, buffered)")
|
||||
asr_provider: str = Field(
|
||||
default="openai_compatible",
|
||||
description="ASR provider (openai_compatible, buffered; siliconflow alias supported)"
|
||||
)
|
||||
siliconflow_asr_model: str = Field(default="FunAudioLLM/SenseVoiceSmall", description="SiliconFlow ASR model")
|
||||
asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
|
||||
asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")
|
||||
|
||||
@@ -30,8 +30,8 @@ from processors.vad import SileroVAD, VADProcessor
|
||||
from services.asr import BufferedASRService
|
||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage, LLMStreamEvent
|
||||
from services.llm import MockLLMService, OpenAILLMService
|
||||
from services.siliconflow_asr import SiliconFlowASRService
|
||||
from services.siliconflow_tts import SiliconFlowTTSService
|
||||
from services.openai_compatible_asr import OpenAICompatibleASRService
|
||||
from services.openai_compatible_tts import OpenAICompatibleTTSService
|
||||
from services.streaming_text import extract_tts_sentence, has_spoken_content
|
||||
from services.tts import EdgeTTSService, MockTTSService
|
||||
|
||||
@@ -60,57 +60,6 @@ class DuplexPipeline:
|
||||
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
|
||||
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0
|
||||
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
|
||||
"search": {
|
||||
"name": "search",
|
||||
"description": "Search the internet for recent information",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
"calculator": {
|
||||
"name": "calculator",
|
||||
"description": "Evaluate a math expression",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"expression": {"type": "string"}},
|
||||
"required": ["expression"],
|
||||
},
|
||||
},
|
||||
"weather": {
|
||||
"name": "weather",
|
||||
"description": "Get weather by city name",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
"translate": {
|
||||
"name": "translate",
|
||||
"description": "Translate text to target language",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string"},
|
||||
"target_lang": {"type": "string"},
|
||||
},
|
||||
"required": ["text", "target_lang"],
|
||||
},
|
||||
},
|
||||
"knowledge": {
|
||||
"name": "knowledge",
|
||||
"description": "Query knowledge base by question",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"kb_id": {"type": "string"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
"current_time": {
|
||||
"name": "current_time",
|
||||
"description": "Get current local time",
|
||||
@@ -120,51 +69,6 @@ class DuplexPipeline:
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
"code_interpreter": {
|
||||
"name": "code_interpreter",
|
||||
"description": "Execute Python code in a controlled environment",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"code": {"type": "string"}},
|
||||
"required": ["code"],
|
||||
},
|
||||
},
|
||||
"turn_on_camera": {
|
||||
"name": "turn_on_camera",
|
||||
"description": "Turn on camera on client device",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
"turn_off_camera": {
|
||||
"name": "turn_off_camera",
|
||||
"description": "Turn off camera on client device",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
"increase_volume": {
|
||||
"name": "increase_volume",
|
||||
"description": "Increase speaker volume",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"step": {"type": "integer"}},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
"decrease_volume": {
|
||||
"name": "decrease_volume",
|
||||
"description": "Decrease speaker volume",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"step": {"type": "integer"}},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@@ -386,6 +290,11 @@ class DuplexPipeline:
|
||||
return False
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_openai_compatible_provider(provider: Any) -> bool:
|
||||
normalized = str(provider or "").strip().lower()
|
||||
return normalized in {"openai_compatible", "openai-compatible", "siliconflow"}
|
||||
|
||||
def _tts_output_enabled(self) -> bool:
|
||||
enabled = self._coerce_bool(self._runtime_tts.get("enabled"))
|
||||
if enabled is not None:
|
||||
@@ -495,15 +404,15 @@ class DuplexPipeline:
|
||||
tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model
|
||||
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
|
||||
|
||||
if tts_provider == "siliconflow" and tts_api_key:
|
||||
self.tts_service = SiliconFlowTTSService(
|
||||
if self._is_openai_compatible_provider(tts_provider) and tts_api_key:
|
||||
self.tts_service = OpenAICompatibleTTSService(
|
||||
api_key=tts_api_key,
|
||||
voice=tts_voice,
|
||||
model=tts_model,
|
||||
sample_rate=settings.sample_rate,
|
||||
speed=tts_speed
|
||||
)
|
||||
logger.info("Using SiliconFlow TTS service")
|
||||
logger.info("Using OpenAI-compatible TTS service (SiliconFlow implementation)")
|
||||
else:
|
||||
self.tts_service = EdgeTTSService(
|
||||
voice=tts_voice,
|
||||
@@ -531,8 +440,8 @@ class DuplexPipeline:
|
||||
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
|
||||
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
|
||||
|
||||
if asr_provider == "siliconflow" and asr_api_key:
|
||||
self.asr_service = SiliconFlowASRService(
|
||||
if self._is_openai_compatible_provider(asr_provider) and asr_api_key:
|
||||
self.asr_service = OpenAICompatibleASRService(
|
||||
api_key=asr_api_key,
|
||||
model=asr_model,
|
||||
sample_rate=settings.sample_rate,
|
||||
@@ -540,7 +449,7 @@ class DuplexPipeline:
|
||||
min_audio_for_interim_ms=asr_min_audio_ms,
|
||||
on_transcript=self._on_transcript_callback
|
||||
)
|
||||
logger.info("Using SiliconFlow ASR service")
|
||||
logger.info("Using OpenAI-compatible ASR service (SiliconFlow implementation)")
|
||||
else:
|
||||
self.asr_service = BufferedASRService(
|
||||
sample_rate=settings.sample_rate
|
||||
|
||||
@@ -66,7 +66,7 @@ Rules:
|
||||
"baseUrl": "https://api.openai.com/v1"
|
||||
},
|
||||
"asr": {
|
||||
"provider": "siliconflow",
|
||||
"provider": "openai_compatible",
|
||||
"model": "FunAudioLLM/SenseVoiceSmall",
|
||||
"apiKey": "sf-...",
|
||||
"interimIntervalMs": 500,
|
||||
@@ -74,7 +74,7 @@ Rules:
|
||||
},
|
||||
"tts": {
|
||||
"enabled": true,
|
||||
"provider": "siliconflow",
|
||||
"provider": "openai_compatible",
|
||||
"model": "FunAudioLLM/CosyVoice2-0.5B",
|
||||
"apiKey": "sf-...",
|
||||
"voice": "anna",
|
||||
|
||||
@@ -15,8 +15,8 @@ from services.base import (
|
||||
from services.llm import OpenAILLMService, MockLLMService
|
||||
from services.tts import EdgeTTSService, MockTTSService
|
||||
from services.asr import BufferedASRService, MockASRService
|
||||
from services.siliconflow_asr import SiliconFlowASRService
|
||||
from services.siliconflow_tts import SiliconFlowTTSService
|
||||
from services.openai_compatible_asr import OpenAICompatibleASRService, SiliconFlowASRService
|
||||
from services.openai_compatible_tts import OpenAICompatibleTTSService, SiliconFlowTTSService
|
||||
from services.streaming_tts_adapter import StreamingTTSAdapter
|
||||
from services.realtime import RealtimeService, RealtimeConfig, RealtimePipeline
|
||||
|
||||
@@ -38,8 +38,10 @@ __all__ = [
|
||||
# ASR
|
||||
"BufferedASRService",
|
||||
"MockASRService",
|
||||
"OpenAICompatibleASRService",
|
||||
"SiliconFlowASRService",
|
||||
# TTS (SiliconFlow)
|
||||
"OpenAICompatibleTTSService",
|
||||
"SiliconFlowTTSService",
|
||||
"StreamingTTSAdapter",
|
||||
# Realtime
|
||||
|
||||
321
engine/services/openai_compatible_asr.py
Normal file
321
engine/services/openai_compatible_asr.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""OpenAI-compatible ASR (Automatic Speech Recognition) Service.
|
||||
|
||||
Uses the SiliconFlow API for speech-to-text transcription.
|
||||
API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcriptions
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import wave
|
||||
from typing import AsyncIterator, Optional, Callable, Awaitable
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
logger.warning("aiohttp not available - OpenAICompatibleASRService will not work")
|
||||
|
||||
from services.base import BaseASRService, ASRResult, ServiceState
|
||||
|
||||
|
||||
class OpenAICompatibleASRService(BaseASRService):
    """
    OpenAI-compatible ASR service for speech-to-text transcription.

    Features:
    - Buffers incoming audio chunks (``send_audio``)
    - Provides interim transcriptions periodically (for streaming to client)
    - Final transcription on EOU (``get_final_transcription``)

    API Details:
    - Endpoint: POST https://api.siliconflow.cn/v1/audio/transcriptions
    - Models: FunAudioLLM/SenseVoiceSmall (default), TeleAI/TeleSpeechASR
    - Input: Audio file (multipart/form-data)
    - Output: {"text": "transcribed text"}
    """

    # Short aliases mapped to full model identifiers.
    # Names not listed here are passed to the API verbatim (see __init__).
    MODELS = {
        "sensevoice": "FunAudioLLM/SenseVoiceSmall",
        "telespeech": "TeleAI/TeleSpeechASR",
    }

    # NOTE(review): the endpoint is still hard-coded to SiliconFlow even though
    # the class is named "OpenAI compatible" -- confirm whether a configurable
    # base URL is intended for other OpenAI-compatible vendors.
    API_URL = "https://api.siliconflow.cn/v1/audio/transcriptions"

    def __init__(
        self,
        api_key: str,
        model: str = "FunAudioLLM/SenseVoiceSmall",
        sample_rate: int = 16000,
        language: str = "auto",
        interim_interval_ms: int = 500,  # How often to send interim results
        min_audio_for_interim_ms: int = 300,  # Min audio before first interim
        on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None
    ):
        """
        Initialize OpenAI-compatible ASR service.

        Args:
            api_key: Provider API key
            model: ASR model name or alias (see ``MODELS``)
            sample_rate: Audio sample rate (16000 recommended)
            language: Language code (auto for automatic detection)
            interim_interval_ms: How often to generate interim transcriptions
            min_audio_for_interim_ms: Minimum audio duration before first interim
            on_transcript: Async callback for transcription results (text, is_final)

        Raises:
            RuntimeError: If aiohttp is not installed.
        """
        super().__init__(sample_rate=sample_rate, language=language)

        if not AIOHTTP_AVAILABLE:
            raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")

        self.api_key = api_key
        # Resolve a short alias (e.g. "sensevoice") to the full model id;
        # unrecognized names are used as-is.
        self.model = self.MODELS.get(model.lower(), model)
        self.interim_interval_ms = interim_interval_ms
        self.min_audio_for_interim_ms = min_audio_for_interim_ms
        self.on_transcript = on_transcript

        # HTTP session, created lazily in connect()
        self._session: Optional[aiohttp.ClientSession] = None

        # Raw PCM audio buffer and most recent transcription text
        self._audio_buffer: bytes = b""
        self._current_text: str = ""
        self._last_interim_time: float = 0

        # Transcript queue consumed by receive_transcripts()
        self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()

        # Background task producing interim results (see _interim_loop)
        self._interim_task: Optional[asyncio.Task] = None
        self._running = False

        logger.info(f"OpenAICompatibleASRService initialized with model: {self.model}")

    async def connect(self) -> None:
        """Create the authenticated HTTP session and mark the service running."""
        self._session = aiohttp.ClientSession(
            headers={
                "Authorization": f"Bearer {self.api_key}"
            }
        )
        self._running = True
        self.state = ServiceState.CONNECTED
        logger.info("OpenAICompatibleASRService connected")

    async def disconnect(self) -> None:
        """Stop the interim task, close the session, and clear all buffers."""
        self._running = False

        if self._interim_task:
            self._interim_task.cancel()
            try:
                await self._interim_task
            except asyncio.CancelledError:
                pass  # expected when cancelling the task
            self._interim_task = None

        if self._session:
            await self._session.close()
            self._session = None

        self._audio_buffer = b""
        self._current_text = ""
        self.state = ServiceState.DISCONNECTED
        logger.info("OpenAICompatibleASRService disconnected")

    async def send_audio(self, audio: bytes) -> None:
        """
        Buffer incoming audio data.

        Args:
            audio: PCM audio data (16-bit, mono)
        """
        self._audio_buffer += audio

    async def transcribe_buffer(self, is_final: bool = False) -> Optional[str]:
        """
        Transcribe current audio buffer.

        Wraps the buffered PCM in a WAV container and posts it to the
        transcriptions endpoint. On success the text is cached, forwarded to
        the ``on_transcript`` callback, and queued for ``receive_transcripts``.

        Args:
            is_final: Whether this is the final transcription

        Returns:
            Transcribed text, or None if there is not enough audio, the
            session is not connected, or the request failed.
        """
        if not self._session:
            logger.warning("ASR session not connected")
            return None

        # Check minimum audio duration (16-bit mono PCM: 2 bytes per sample)
        audio_duration_ms = len(self._audio_buffer) / (self.sample_rate * 2) * 1000

        if not is_final and audio_duration_ms < self.min_audio_for_interim_ms:
            return None

        if audio_duration_ms < 100:  # Less than 100ms - too short
            return None

        try:
            # Convert raw PCM to an in-memory WAV file
            wav_buffer = io.BytesIO()
            with wave.open(wav_buffer, 'wb') as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(self.sample_rate)
                wav_file.writeframes(self._audio_buffer)

            wav_buffer.seek(0)
            wav_data = wav_buffer.read()

            # Send to API as multipart/form-data
            form_data = aiohttp.FormData()
            form_data.add_field(
                'file',
                wav_data,
                filename='audio.wav',
                content_type='audio/wav'
            )
            form_data.add_field('model', self.model)

            async with self._session.post(self.API_URL, data=form_data) as response:
                if response.status == 200:
                    result = await response.json()
                    text = result.get("text", "").strip()

                    if text:
                        self._current_text = text

                        # Notify via callback
                        if self.on_transcript:
                            await self.on_transcript(text, is_final)

                        # Queue result for receive_transcripts() consumers
                        await self._transcript_queue.put(
                            ASRResult(text=text, is_final=is_final)
                        )

                        logger.debug(f"ASR {'final' if is_final else 'interim'}: {text[:50]}...")
                        return text
                else:
                    error_text = await response.text()
                    logger.error(f"ASR API error {response.status}: {error_text}")
                    return None

        except Exception as e:
            logger.error(f"ASR transcription error: {e}")
            return None

    async def get_final_transcription(self) -> str:
        """
        Get final transcription and clear buffer.

        Call this when EOU is detected.

        Returns:
            Final transcribed text; falls back to the last interim result if
            the final request produced nothing.
        """
        # Transcribe full buffer as final
        text = await self.transcribe_buffer(is_final=True)

        # Fall back to last interim text, then clear state
        result = text or self._current_text
        self._audio_buffer = b""
        self._current_text = ""

        return result

    def get_and_clear_text(self) -> str:
        """
        Get accumulated text and clear buffer.

        Compatible with BufferedASRService interface.
        """
        text = self._current_text
        self._current_text = ""
        self._audio_buffer = b""
        return text

    def get_audio_buffer(self) -> bytes:
        """Get current audio buffer (raw PCM bytes)."""
        return self._audio_buffer

    def get_audio_duration_ms(self) -> float:
        """Get current audio buffer duration in milliseconds (16-bit mono PCM)."""
        return len(self._audio_buffer) / (self.sample_rate * 2) * 1000

    def clear_buffer(self) -> None:
        """Clear audio and text buffers."""
        self._audio_buffer = b""
        self._current_text = ""

    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """
        Async iterator for transcription results.

        Polls the internal queue with a short timeout so the loop can observe
        ``self._running`` and exit promptly after disconnect().

        Yields:
            ASRResult with text and is_final flag
        """
        while self._running:
            try:
                result = await asyncio.wait_for(
                    self._transcript_queue.get(),
                    timeout=0.1
                )
                yield result
            except asyncio.TimeoutError:
                continue  # nothing queued yet; re-check _running
            except asyncio.CancelledError:
                break

    async def start_interim_transcription(self) -> None:
        """
        Start background task for interim transcriptions.

        This periodically transcribes buffered audio for
        real-time feedback to the user. No-op if a task is already running.
        """
        if self._interim_task and not self._interim_task.done():
            return  # already running

        self._interim_task = asyncio.create_task(self._interim_loop())

    async def stop_interim_transcription(self) -> None:
        """Stop interim transcription task."""
        if self._interim_task:
            self._interim_task.cancel()
            try:
                await self._interim_task
            except asyncio.CancelledError:
                pass  # expected when cancelling the task
            self._interim_task = None

    async def _interim_loop(self) -> None:
        """Background loop for interim transcriptions."""
        import time

        while self._running:
            try:
                await asyncio.sleep(self.interim_interval_ms / 1000)

                # Throttle: only transcribe when a full interval has elapsed
                current_time = time.time()
                time_since_last = (current_time - self._last_interim_time) * 1000

                if time_since_last >= self.interim_interval_ms:
                    audio_duration = self.get_audio_duration_ms()

                    if audio_duration >= self.min_audio_for_interim_ms:
                        await self.transcribe_buffer(is_final=False)
                        self._last_interim_time = current_time

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Interim transcription error: {e}")
|
||||
|
||||
|
||||
# Backward-compatible alias so existing imports of the old SiliconFlow
# class name keep working after the OpenAI-compatible rename.
SiliconFlowASRService = OpenAICompatibleASRService
|
||||
315
engine/services/openai_compatible_tts.py
Normal file
315
engine/services/openai_compatible_tts.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""OpenAI-compatible TTS Service with streaming support.
|
||||
|
||||
Uses SiliconFlow's CosyVoice2 or MOSS-TTSD models for low-latency
|
||||
text-to-speech synthesis with streaming.
|
||||
|
||||
API Docs: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from typing import AsyncIterator, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.base import BaseTTSService, TTSChunk, ServiceState
|
||||
from services.streaming_tts_adapter import StreamingTTSAdapter # backward-compatible re-export
|
||||
|
||||
|
||||
class OpenAICompatibleTTSService(BaseTTSService):
    """
    OpenAI-compatible TTS service with streaming support.

    Supports CosyVoice2-0.5B and MOSS-TTSD-v0.5 models.
    """

    # Available voices: short name -> full "model:voice" identifier
    VOICES = {
        "alex": "FunAudioLLM/CosyVoice2-0.5B:alex",
        "anna": "FunAudioLLM/CosyVoice2-0.5B:anna",
        "bella": "FunAudioLLM/CosyVoice2-0.5B:bella",
        "benjamin": "FunAudioLLM/CosyVoice2-0.5B:benjamin",
        "charles": "FunAudioLLM/CosyVoice2-0.5B:charles",
        "claire": "FunAudioLLM/CosyVoice2-0.5B:claire",
        "david": "FunAudioLLM/CosyVoice2-0.5B:david",
        "diana": "FunAudioLLM/CosyVoice2-0.5B:diana",
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        voice: str = "anna",
        model: str = "FunAudioLLM/CosyVoice2-0.5B",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        """
        Initialize OpenAI-compatible TTS service.

        Args:
            api_key: Provider API key (defaults to SILICONFLOW_API_KEY env var)
            voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
            model: Model name
            sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
            speed: Speech speed (0.25 to 4.0)
        """
        # Resolve short voice name to full "model:voice" id;
        # unknown names are passed through unchanged.
        if voice in self.VOICES:
            full_voice = self.VOICES[voice]
        else:
            full_voice = voice

        super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed)

        self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY")
        self.model = model
        # NOTE(review): endpoint is hard-coded to SiliconFlow despite the
        # "OpenAI compatible" rename -- confirm whether a configurable base URL
        # is intended for other vendors.
        self.api_url = "https://api.siliconflow.cn/v1/audio/speech"

        self._session: Optional[aiohttp.ClientSession] = None
        # Set by cancel() to abort an in-progress synthesize_stream()
        self._cancel_event = asyncio.Event()

    async def connect(self) -> None:
        """Initialize HTTP session.

        Raises:
            ValueError: If no API key was provided or found in the environment.
        """
        if not self.api_key:
            raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.")

        self._session = aiohttp.ClientSession(
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
        )
        self.state = ServiceState.CONNECTED
        # NOTE(review): log messages still say "SiliconFlow" after the rename.
        logger.info(f"SiliconFlow TTS service ready: voice={self.voice}, model={self.model}")

    async def disconnect(self) -> None:
        """Close HTTP session."""
        if self._session:
            await self._session.close()
            self._session = None
        self.state = ServiceState.DISCONNECTED
        logger.info("SiliconFlow TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Synthesize complete audio for text by draining the streaming API."""
        audio_data = b""
        async for chunk in self.synthesize_stream(text):
            audio_data += chunk.audio
        return audio_data

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """
        Synthesize audio in streaming mode.

        The raw PCM response is re-chunked into ~100ms pieces. One full chunk
        is always held back so the true last chunk can be tagged is_final even
        when the stream length is an exact multiple of the chunk size.

        Args:
            text: Text to synthesize

        Yields:
            TTSChunk objects with PCM audio

        Raises:
            RuntimeError: If called before connect().
        """
        if not self._session:
            raise RuntimeError("TTS service not connected")

        if not text.strip():
            return  # nothing to synthesize

        self._cancel_event.clear()

        payload = {
            "model": self.model,
            "input": text,
            "voice": self.voice,
            "response_format": "pcm",
            "sample_rate": self.sample_rate,
            "stream": True,
            "speed": self.speed
        }

        try:
            async with self._session.post(self.api_url, json=payload) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"SiliconFlow TTS error: {response.status} - {error_text}")
                    return

                # Stream audio chunks: 16-bit PCM = 2 bytes/sample, 100ms each
                chunk_size = self.sample_rate * 2 // 10  # 100ms chunks
                buffer = b""
                pending_chunk = None

                async for chunk in response.content.iter_any():
                    if self._cancel_event.is_set():
                        logger.info("TTS synthesis cancelled")
                        return

                    buffer += chunk

                    # Yield complete chunks
                    while len(buffer) >= chunk_size:
                        audio_chunk = buffer[:chunk_size]
                        buffer = buffer[chunk_size:]

                        # Keep one full chunk buffered so we can always tag the true
                        # last full chunk as final when stream length is an exact multiple.
                        if pending_chunk is not None:
                            yield TTSChunk(
                                audio=pending_chunk,
                                sample_rate=self.sample_rate,
                                is_final=False
                            )
                        pending_chunk = audio_chunk

                # Flush pending chunk(s) and remaining tail.
                if pending_chunk is not None:
                    if buffer:
                        # A partial tail remains, so the pending chunk is not final.
                        yield TTSChunk(
                            audio=pending_chunk,
                            sample_rate=self.sample_rate,
                            is_final=False
                        )
                        pending_chunk = None
                    else:
                        yield TTSChunk(
                            audio=pending_chunk,
                            sample_rate=self.sample_rate,
                            is_final=True
                        )
                        pending_chunk = None

                if buffer:
                    yield TTSChunk(
                        audio=buffer,
                        sample_rate=self.sample_rate,
                        is_final=True
                    )

        except asyncio.CancelledError:
            logger.info("TTS synthesis cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"TTS synthesis error: {e}")
            raise

    async def cancel(self) -> None:
        """Cancel ongoing synthesis (checked between streamed chunks)."""
        self._cancel_event.set()
|
||||
|
||||
|
||||
class StreamingTTSAdapter:
    """
    Adapter for streaming LLM text to TTS with sentence-level chunking.

    This reduces latency by starting TTS as soon as a complete sentence
    is received from the LLM, rather than waiting for the full response.
    """

    # Sentence delimiters.
    # NOTE(review): '!' and '?' appear twice in this literal; duplicates
    # collapse in a set, so the second occurrences were most likely meant to
    # be the full-width variants -- verify against the original source.
    SENTENCE_ENDS = {',', '。', '!', '?', '.', '!', '?', '\n'}

    def __init__(self, tts_service: BaseTTSService, transport, session_id: str):
        # Underlying TTS service used for synthesis
        self.tts_service = tts_service
        # Transport exposing an async send_audio(bytes) method (see _speak_sentence)
        self.transport = transport
        self.session_id = session_id
        # Accumulates LLM text until a sentence delimiter is found
        self._buffer = ""
        self._cancel_event = asyncio.Event()
        self._is_speaking = False

    def _is_non_sentence_period(self, text: str, idx: int) -> bool:
        """Check whether '.' at ``idx`` should NOT be treated as a sentence delimiter."""
        if text[idx] != ".":
            return False

        # Decimal/version segment: 1.2, v1.2.3
        if idx > 0 and idx < len(text) - 1 and text[idx - 1].isdigit() and text[idx + 1].isdigit():
            return True

        # Number abbreviations: No.1 / No. 1
        left_start = idx - 1
        while left_start >= 0 and text[left_start].isalpha():
            left_start -= 1
        left_token = text[left_start + 1:idx].lower()
        if left_token == "no":
            j = idx + 1
            # Skip optional whitespace after the period, then require a digit
            while j < len(text) and text[j].isspace():
                j += 1
            if j < len(text) and text[j].isdigit():
                return True

        return False

    async def process_text_chunk(self, text_chunk: str) -> None:
        """
        Process a text chunk from LLM and trigger TTS when sentence is complete.

        Args:
            text_chunk: Text chunk from LLM streaming
        """
        if self._cancel_event.is_set():
            return  # speech was cancelled; drop incoming text

        self._buffer += text_chunk

        # Repeatedly split off and speak every complete sentence in the buffer
        while True:
            split_idx = -1
            for i, char in enumerate(self._buffer):
                # Periods inside numbers/abbreviations do not end a sentence
                if char == "." and self._is_non_sentence_period(self._buffer, i):
                    continue
                if char in self.SENTENCE_ENDS:
                    split_idx = i
                    break
            if split_idx < 0:
                break  # no complete sentence yet; wait for more text

            # Extend past any run of consecutive delimiters ("?!", "...")
            end_idx = split_idx + 1
            while end_idx < len(self._buffer) and self._buffer[end_idx] in self.SENTENCE_ENDS:
                end_idx += 1

            sentence = self._buffer[:end_idx].strip()
            self._buffer = self._buffer[end_idx:]

            # Only speak sentences with real content, not bare punctuation
            if sentence and any(ch.isalnum() for ch in sentence):
                await self._speak_sentence(sentence)

    async def flush(self) -> None:
        """Speak any remaining buffered text (call at the end of an LLM response)."""
        if self._buffer.strip() and not self._cancel_event.is_set():
            await self._speak_sentence(self._buffer.strip())
        self._buffer = ""

    async def _speak_sentence(self, text: str) -> None:
        """Synthesize a sentence and stream its audio chunks to the transport."""
        if not text or self._cancel_event.is_set():
            return

        self._is_speaking = True

        try:
            async for chunk in self.tts_service.synthesize_stream(text):
                if self._cancel_event.is_set():
                    break
                await self.transport.send_audio(chunk.audio)
                await asyncio.sleep(0.01)  # Prevent flooding the transport
        except Exception as e:
            logger.error(f"TTS speak error: {e}")
        finally:
            self._is_speaking = False

    def cancel(self) -> None:
        """Cancel ongoing speech and drop any buffered text."""
        self._cancel_event.set()
        self._buffer = ""

    def reset(self) -> None:
        """Reset for a new turn (clears cancellation flag and buffers)."""
        self._cancel_event.clear()
        self._buffer = ""
        self._is_speaking = False

    @property
    def is_speaking(self) -> bool:
        # True while _speak_sentence is actively streaming audio
        return self._is_speaking
|
||||
|
||||
|
||||
# Backward-compatible alias so existing imports of the old SiliconFlow
# class name keep working after the OpenAI-compatible rename.
SiliconFlowTTSService = OpenAICompatibleTTSService
|
||||
@@ -1,317 +1,8 @@
|
||||
"""SiliconFlow ASR (Automatic Speech Recognition) Service.
|
||||
"""Backward-compatible imports for legacy siliconflow_asr module."""
|
||||
|
||||
Uses the SiliconFlow API for speech-to-text transcription.
|
||||
API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcriptions
|
||||
"""
|
||||
from services.openai_compatible_asr import OpenAICompatibleASRService
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import wave
|
||||
from typing import AsyncIterator, Optional, Callable, Awaitable
|
||||
from loguru import logger
|
||||
# Backward-compatible alias
|
||||
SiliconFlowASRService = OpenAICompatibleASRService
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
logger.warning("aiohttp not available - SiliconFlowASRService will not work")
|
||||
|
||||
from services.base import BaseASRService, ASRResult, ServiceState
|
||||
|
||||
|
||||
class SiliconFlowASRService(BaseASRService):
|
||||
"""
|
||||
SiliconFlow ASR service for speech-to-text transcription.
|
||||
|
||||
Features:
|
||||
- Buffers incoming audio chunks
|
||||
- Provides interim transcriptions periodically (for streaming to client)
|
||||
- Final transcription on EOU
|
||||
|
||||
API Details:
|
||||
- Endpoint: POST https://api.siliconflow.cn/v1/audio/transcriptions
|
||||
- Models: FunAudioLLM/SenseVoiceSmall (default), TeleAI/TeleSpeechASR
|
||||
- Input: Audio file (multipart/form-data)
|
||||
- Output: {"text": "transcribed text"}
|
||||
"""
|
||||
|
||||
# Supported models
|
||||
MODELS = {
|
||||
"sensevoice": "FunAudioLLM/SenseVoiceSmall",
|
||||
"telespeech": "TeleAI/TeleSpeechASR",
|
||||
}
|
||||
|
||||
API_URL = "https://api.siliconflow.cn/v1/audio/transcriptions"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
model: str = "FunAudioLLM/SenseVoiceSmall",
|
||||
sample_rate: int = 16000,
|
||||
language: str = "auto",
|
||||
interim_interval_ms: int = 500, # How often to send interim results
|
||||
min_audio_for_interim_ms: int = 300, # Min audio before first interim
|
||||
on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None
|
||||
):
|
||||
"""
|
||||
Initialize SiliconFlow ASR service.
|
||||
|
||||
Args:
|
||||
api_key: SiliconFlow API key
|
||||
model: ASR model name or alias
|
||||
sample_rate: Audio sample rate (16000 recommended)
|
||||
language: Language code (auto for automatic detection)
|
||||
interim_interval_ms: How often to generate interim transcriptions
|
||||
min_audio_for_interim_ms: Minimum audio duration before first interim
|
||||
on_transcript: Callback for transcription results (text, is_final)
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
raise RuntimeError("aiohttp is required for SiliconFlowASRService")
|
||||
|
||||
self.api_key = api_key
|
||||
self.model = self.MODELS.get(model.lower(), model)
|
||||
self.interim_interval_ms = interim_interval_ms
|
||||
self.min_audio_for_interim_ms = min_audio_for_interim_ms
|
||||
self.on_transcript = on_transcript
|
||||
|
||||
# Session
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
# Audio buffer
|
||||
self._audio_buffer: bytes = b""
|
||||
self._current_text: str = ""
|
||||
self._last_interim_time: float = 0
|
||||
|
||||
# Transcript queue for async iteration
|
||||
self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
|
||||
|
||||
# Background task for interim results
|
||||
self._interim_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
|
||||
logger.info(f"SiliconFlowASRService initialized with model: {self.model}")
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to the service."""
|
||||
self._session = aiohttp.ClientSession(
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
)
|
||||
self._running = True
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info("SiliconFlowASRService connected")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect and cleanup."""
|
||||
self._running = False
|
||||
|
||||
if self._interim_task:
|
||||
self._interim_task.cancel()
|
||||
try:
|
||||
await self._interim_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._interim_task = None
|
||||
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
self._audio_buffer = b""
|
||||
self._current_text = ""
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("SiliconFlowASRService disconnected")
|
||||
|
||||
async def send_audio(self, audio: bytes) -> None:
|
||||
"""
|
||||
Buffer incoming audio data.
|
||||
|
||||
Args:
|
||||
audio: PCM audio data (16-bit, mono)
|
||||
"""
|
||||
self._audio_buffer += audio
|
||||
|
||||
async def transcribe_buffer(self, is_final: bool = False) -> Optional[str]:
|
||||
"""
|
||||
Transcribe current audio buffer.
|
||||
|
||||
Args:
|
||||
is_final: Whether this is the final transcription
|
||||
|
||||
Returns:
|
||||
Transcribed text or None if not enough audio
|
||||
"""
|
||||
if not self._session:
|
||||
logger.warning("ASR session not connected")
|
||||
return None
|
||||
|
||||
# Check minimum audio duration
|
||||
audio_duration_ms = len(self._audio_buffer) / (self.sample_rate * 2) * 1000
|
||||
|
||||
if not is_final and audio_duration_ms < self.min_audio_for_interim_ms:
|
||||
return None
|
||||
|
||||
if audio_duration_ms < 100: # Less than 100ms - too short
|
||||
return None
|
||||
|
||||
try:
|
||||
# Convert PCM to WAV in memory
|
||||
wav_buffer = io.BytesIO()
|
||||
with wave.open(wav_buffer, 'wb') as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2) # 16-bit
|
||||
wav_file.setframerate(self.sample_rate)
|
||||
wav_file.writeframes(self._audio_buffer)
|
||||
|
||||
wav_buffer.seek(0)
|
||||
wav_data = wav_buffer.read()
|
||||
|
||||
# Send to API
|
||||
form_data = aiohttp.FormData()
|
||||
form_data.add_field(
|
||||
'file',
|
||||
wav_data,
|
||||
filename='audio.wav',
|
||||
content_type='audio/wav'
|
||||
)
|
||||
form_data.add_field('model', self.model)
|
||||
|
||||
async with self._session.post(self.API_URL, data=form_data) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
text = result.get("text", "").strip()
|
||||
|
||||
if text:
|
||||
self._current_text = text
|
||||
|
||||
# Notify via callback
|
||||
if self.on_transcript:
|
||||
await self.on_transcript(text, is_final)
|
||||
|
||||
# Queue result
|
||||
await self._transcript_queue.put(
|
||||
ASRResult(text=text, is_final=is_final)
|
||||
)
|
||||
|
||||
logger.debug(f"ASR {'final' if is_final else 'interim'}: {text[:50]}...")
|
||||
return text
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(f"ASR API error {response.status}: {error_text}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ASR transcription error: {e}")
|
||||
return None
|
||||
|
||||
async def get_final_transcription(self) -> str:
|
||||
"""
|
||||
Get final transcription and clear buffer.
|
||||
|
||||
Call this when EOU is detected.
|
||||
|
||||
Returns:
|
||||
Final transcribed text
|
||||
"""
|
||||
# Transcribe full buffer as final
|
||||
text = await self.transcribe_buffer(is_final=True)
|
||||
|
||||
# Clear buffer
|
||||
result = text or self._current_text
|
||||
self._audio_buffer = b""
|
||||
self._current_text = ""
|
||||
|
||||
return result
|
||||
|
||||
def get_and_clear_text(self) -> str:
|
||||
"""
|
||||
Get accumulated text and clear buffer.
|
||||
|
||||
Compatible with BufferedASRService interface.
|
||||
"""
|
||||
text = self._current_text
|
||||
self._current_text = ""
|
||||
self._audio_buffer = b""
|
||||
return text
|
||||
|
||||
def get_audio_buffer(self) -> bytes:
|
||||
"""Get current audio buffer."""
|
||||
return self._audio_buffer
|
||||
|
||||
def get_audio_duration_ms(self) -> float:
|
||||
"""Get current audio buffer duration in milliseconds."""
|
||||
return len(self._audio_buffer) / (self.sample_rate * 2) * 1000
|
||||
|
||||
def clear_buffer(self) -> None:
|
||||
"""Clear audio and text buffers."""
|
||||
self._audio_buffer = b""
|
||||
self._current_text = ""
|
||||
|
||||
async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
|
||||
"""
|
||||
Async iterator for transcription results.
|
||||
|
||||
Yields:
|
||||
ASRResult with text and is_final flag
|
||||
"""
|
||||
while self._running:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._transcript_queue.get(),
|
||||
timeout=0.1
|
||||
)
|
||||
yield result
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
async def start_interim_transcription(self) -> None:
|
||||
"""
|
||||
Start background task for interim transcriptions.
|
||||
|
||||
This periodically transcribes buffered audio for
|
||||
real-time feedback to the user.
|
||||
"""
|
||||
if self._interim_task and not self._interim_task.done():
|
||||
return
|
||||
|
||||
self._interim_task = asyncio.create_task(self._interim_loop())
|
||||
|
||||
async def stop_interim_transcription(self) -> None:
|
||||
"""Stop interim transcription task."""
|
||||
if self._interim_task:
|
||||
self._interim_task.cancel()
|
||||
try:
|
||||
await self._interim_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._interim_task = None
|
||||
|
||||
async def _interim_loop(self) -> None:
|
||||
"""Background loop for interim transcriptions."""
|
||||
import time
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(self.interim_interval_ms / 1000)
|
||||
|
||||
# Check if we have enough new audio
|
||||
current_time = time.time()
|
||||
time_since_last = (current_time - self._last_interim_time) * 1000
|
||||
|
||||
if time_since_last >= self.interim_interval_ms:
|
||||
audio_duration = self.get_audio_duration_ms()
|
||||
|
||||
if audio_duration >= self.min_audio_for_interim_ms:
|
||||
await self.transcribe_buffer(is_final=False)
|
||||
self._last_interim_time = current_time
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Interim transcription error: {e}")
|
||||
__all__ = ["OpenAICompatibleASRService", "SiliconFlowASRService"]
|
||||
|
||||
@@ -1,311 +1,8 @@
|
||||
"""SiliconFlow TTS Service with streaming support.
|
||||
"""Backward-compatible imports for legacy siliconflow_tts module."""
|
||||
|
||||
Uses SiliconFlow's CosyVoice2 or MOSS-TTSD models for low-latency
|
||||
text-to-speech synthesis with streaming.
|
||||
from services.openai_compatible_tts import OpenAICompatibleTTSService, StreamingTTSAdapter
|
||||
|
||||
API Docs: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
|
||||
"""
|
||||
# Backward-compatible alias
|
||||
SiliconFlowTTSService = OpenAICompatibleTTSService
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from typing import AsyncIterator, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.base import BaseTTSService, TTSChunk, ServiceState
|
||||
from services.streaming_tts_adapter import StreamingTTSAdapter # backward-compatible re-export
|
||||
|
||||
|
||||
class SiliconFlowTTSService(BaseTTSService):
    """
    SiliconFlow TTS service with streaming support.

    Supports CosyVoice2-0.5B and MOSS-TTSD-v0.5 models.

    Lifecycle: call connect() once (creates the aiohttp session), then
    synthesize()/synthesize_stream() any number of times, then disconnect().
    cancel() aborts an in-flight synthesize_stream() at the next chunk
    boundary.
    """

    # Available voices: short alias -> fully-qualified "model:voice" key
    # accepted by the SiliconFlow speech endpoint.
    VOICES = {
        "alex": "FunAudioLLM/CosyVoice2-0.5B:alex",
        "anna": "FunAudioLLM/CosyVoice2-0.5B:anna",
        "bella": "FunAudioLLM/CosyVoice2-0.5B:bella",
        "benjamin": "FunAudioLLM/CosyVoice2-0.5B:benjamin",
        "charles": "FunAudioLLM/CosyVoice2-0.5B:charles",
        "claire": "FunAudioLLM/CosyVoice2-0.5B:claire",
        "david": "FunAudioLLM/CosyVoice2-0.5B:david",
        "diana": "FunAudioLLM/CosyVoice2-0.5B:diana",
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        voice: str = "anna",
        model: str = "FunAudioLLM/CosyVoice2-0.5B",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        """
        Initialize SiliconFlow TTS service.

        Args:
            api_key: SiliconFlow API key (defaults to SILICONFLOW_API_KEY env var)
            voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
                or an already fully-qualified "model:voice" key, which is passed
                through unchanged.
            model: Model name
            sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
            speed: Speech speed (0.25 to 4.0)
        """
        # Resolve voice name: short aliases map through VOICES; anything else
        # is assumed to already be a full voice key and is used as-is.
        if voice in self.VOICES:
            full_voice = self.VOICES[voice]
        else:
            full_voice = voice

        super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed)

        self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY")
        self.model = model
        self.api_url = "https://api.siliconflow.cn/v1/audio/speech"

        # Created lazily in connect(); None while disconnected.
        self._session: Optional[aiohttp.ClientSession] = None
        # Set by cancel(); checked between chunks in synthesize_stream().
        self._cancel_event = asyncio.Event()

    async def connect(self) -> None:
        """Initialize HTTP session.

        Raises:
            ValueError: if no API key was provided or found in the environment.
        """
        if not self.api_key:
            raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.")

        # NOTE(review): calling connect() twice replaces self._session without
        # closing the previous one — presumably callers connect only once;
        # confirm, or close the old session here.
        self._session = aiohttp.ClientSession(
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
        )
        self.state = ServiceState.CONNECTED
        logger.info(f"SiliconFlow TTS service ready: voice={self.voice}, model={self.model}")

    async def disconnect(self) -> None:
        """Close HTTP session. Safe to call when already disconnected."""
        if self._session:
            await self._session.close()
            self._session = None
        self.state = ServiceState.DISCONNECTED
        logger.info("SiliconFlow TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Synthesize complete audio for text.

        Convenience wrapper that drains synthesize_stream() and concatenates
        the PCM chunks into one bytes object.
        """
        audio_data = b""
        async for chunk in self.synthesize_stream(text):
            audio_data += chunk.audio
        return audio_data

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """
        Synthesize audio in streaming mode.

        Yields fixed-size PCM chunks as they arrive; the last chunk carries
        is_final=True. Whitespace-only input yields nothing. On a non-200
        API response the error is logged and the generator ends without
        yielding (and without raising) — callers cannot distinguish this
        from empty audio.

        Args:
            text: Text to synthesize

        Yields:
            TTSChunk objects with PCM audio

        Raises:
            RuntimeError: if connect() has not been called.
        """
        if not self._session:
            raise RuntimeError("TTS service not connected")

        if not text.strip():
            return

        # Fresh run: clear any cancellation left over from a previous stream.
        self._cancel_event.clear()

        payload = {
            "model": self.model,
            "input": text,
            "voice": self.voice,
            "response_format": "pcm",
            "sample_rate": self.sample_rate,
            "stream": True,
            "speed": self.speed
        }

        try:
            async with self._session.post(self.api_url, json=payload) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"SiliconFlow TTS error: {response.status} - {error_text}")
                    # Deliberate best-effort: swallow the error after logging.
                    return

                # Stream audio chunks
                # 100ms chunks — assumes 16-bit (2 bytes/sample) mono PCM;
                # TODO confirm against the API's pcm format.
                chunk_size = self.sample_rate * 2 // 10  # 100ms chunks
                buffer = b""
                pending_chunk = None

                async for chunk in response.content.iter_any():
                    if self._cancel_event.is_set():
                        logger.info("TTS synthesis cancelled")
                        return

                    buffer += chunk

                    # Yield complete chunks
                    while len(buffer) >= chunk_size:
                        audio_chunk = buffer[:chunk_size]
                        buffer = buffer[chunk_size:]

                        # Keep one full chunk buffered so we can always tag the true
                        # last full chunk as final when stream length is an exact multiple.
                        if pending_chunk is not None:
                            yield TTSChunk(
                                audio=pending_chunk,
                                sample_rate=self.sample_rate,
                                is_final=False
                            )
                        pending_chunk = audio_chunk

                # Flush pending chunk(s) and remaining tail.
                if pending_chunk is not None:
                    if buffer:
                        # A partial tail remains, so the held chunk is not last.
                        yield TTSChunk(
                            audio=pending_chunk,
                            sample_rate=self.sample_rate,
                            is_final=False
                        )
                        pending_chunk = None
                    else:
                        # Stream length was an exact multiple of chunk_size:
                        # the held chunk is the true final chunk.
                        yield TTSChunk(
                            audio=pending_chunk,
                            sample_rate=self.sample_rate,
                            is_final=True
                        )
                        pending_chunk = None

                if buffer:
                    # Sub-chunk remainder is emitted as the final chunk.
                    yield TTSChunk(
                        audio=buffer,
                        sample_rate=self.sample_rate,
                        is_final=True
                    )

        except asyncio.CancelledError:
            logger.info("TTS synthesis cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"TTS synthesis error: {e}")
            raise

    async def cancel(self) -> None:
        """Cancel ongoing synthesis.

        The running synthesize_stream() observes the event at the next chunk
        boundary and stops yielding.
        """
        self._cancel_event.set()
|
||||
|
||||
|
||||
class StreamingTTSAdapter:
    """
    Bridges streaming LLM text output to a TTS service using
    sentence-level chunking.

    Audio synthesis begins as soon as one complete sentence has been
    buffered, instead of waiting for the LLM's full response, which cuts
    perceived latency.

    NOTE(review): this local definition shadows the ``StreamingTTSAdapter``
    imported above from ``services.streaming_tts_adapter`` ("backward-compatible
    re-export") — one of the two is redundant; confirm which is intended.
    """

    # Characters that terminate a sentence (mixed CJK / ASCII punctuation).
    SENTENCE_ENDS = {',', '。', '!', '?', '.', '!', '?', '\n'}

    def __init__(self, tts_service: BaseTTSService, transport, session_id: str):
        self.tts_service = tts_service
        self.transport = transport
        self.session_id = session_id
        self._buffer = ""
        self._cancel_event = asyncio.Event()
        self._is_speaking = False

    def _is_non_sentence_period(self, text: str, idx: int) -> bool:
        """Return True when the '.' at *idx* should NOT end a sentence."""
        if text[idx] != ".":
            return False

        # Digit on both sides -> decimal or version segment: 1.2, v1.2.3
        if 0 < idx < len(text) - 1 and text[idx - 1].isdigit() and text[idx + 1].isdigit():
            return True

        # Number abbreviation ("No.") followed, possibly after spaces, by a
        # digit: No.1 / No. 1
        start = idx
        while start > 0 and text[start - 1].isalpha():
            start -= 1
        if text[start:idx].lower() == "no":
            after = idx + 1
            while after < len(text) and text[after].isspace():
                after += 1
            if after < len(text) and text[after].isdigit():
                return True

        return False

    def _find_sentence_end(self) -> int:
        """Index of the first real sentence delimiter in the buffer, or -1."""
        for pos, ch in enumerate(self._buffer):
            if ch == "." and self._is_non_sentence_period(self._buffer, pos):
                continue
            if ch in self.SENTENCE_ENDS:
                return pos
        return -1

    async def process_text_chunk(self, text_chunk: str) -> None:
        """
        Accumulate an LLM text chunk and speak each completed sentence.

        Args:
            text_chunk: Text chunk from LLM streaming
        """
        if self._cancel_event.is_set():
            return

        self._buffer += text_chunk

        # Drain every complete sentence currently sitting in the buffer.
        while True:
            cut = self._find_sentence_end()
            if cut < 0:
                break

            # Absorb any run of trailing delimiters ("?!", "。。", ...).
            end = cut + 1
            while end < len(self._buffer) and self._buffer[end] in self.SENTENCE_ENDS:
                end += 1

            sentence = self._buffer[:end].strip()
            self._buffer = self._buffer[end:]

            # Skip punctuation-only fragments.
            if sentence and any(c.isalnum() for c in sentence):
                await self._speak_sentence(sentence)

    async def flush(self) -> None:
        """Speak whatever text is still buffered, then clear the buffer."""
        leftover = self._buffer.strip()
        if leftover and not self._cancel_event.is_set():
            await self._speak_sentence(leftover)
        self._buffer = ""

    async def _speak_sentence(self, text: str) -> None:
        """Run TTS for one sentence and forward the audio to the transport."""
        if not text or self._cancel_event.is_set():
            return

        self._is_speaking = True

        try:
            async for piece in self.tts_service.synthesize_stream(text):
                if self._cancel_event.is_set():
                    break
                await self.transport.send_audio(piece.audio)
                # Small pause so the transport isn't flooded with audio frames.
                await asyncio.sleep(0.01)
        except Exception as e:
            logger.error(f"TTS speak error: {e}")
        finally:
            self._is_speaking = False

    def cancel(self) -> None:
        """Abort any in-flight speech and drop pending text."""
        self._cancel_event.set()
        self._buffer = ""

    def reset(self) -> None:
        """Prepare the adapter for a fresh conversational turn."""
        self._cancel_event.clear()
        self._buffer = ""
        self._is_speaking = False

    @property
    def is_speaking(self) -> bool:
        """True while a sentence is actively being synthesized and sent."""
        return self._is_speaking
|
||||
# NOTE(review): "OpenAICompatibleTTSService" is exported here but no definition
# is visible in this part of the module — presumably an alias for
# SiliconFlowTTSService defined elsewhere; confirm it exists, otherwise
# `from module import *` will raise AttributeError.
__all__ = ["OpenAICompatibleTTSService", "SiliconFlowTTSService", "StreamingTTSAdapter"]
|
||||
|
||||
@@ -85,7 +85,7 @@ const convertRecordedBlobToWav = async (blob: Blob): Promise<File> => {
|
||||
export const ASRLibraryPage: React.FC = () => {
|
||||
const [models, setModels] = useState<ASRModel[]>([]);
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const [vendorFilter, setVendorFilter] = useState<string>('all');
|
||||
const [vendorFilter, setVendorFilter] = useState<string>('OpenAI Compatible');
|
||||
const [langFilter, setLangFilter] = useState<string>('all');
|
||||
const [isAddModalOpen, setIsAddModalOpen] = useState(false);
|
||||
const [editingModel, setEditingModel] = useState<ASRModel | null>(null);
|
||||
@@ -111,7 +111,7 @@ export const ASRLibraryPage: React.FC = () => {
|
||||
const filteredModels = models.filter((m) => {
|
||||
const q = searchTerm.toLowerCase();
|
||||
const matchesSearch = m.name.toLowerCase().includes(q) || (m.modelName || '').toLowerCase().includes(q);
|
||||
const matchesVendor = vendorFilter === 'all' || m.vendor === vendorFilter;
|
||||
const matchesVendor = m.vendor === vendorFilter;
|
||||
const matchesLang = langFilter === 'all' || m.language === langFilter || (langFilter !== 'all' && m.language === 'Multi-lingual');
|
||||
return matchesSearch && matchesVendor && matchesLang;
|
||||
});
|
||||
@@ -134,8 +134,6 @@ export const ASRLibraryPage: React.FC = () => {
|
||||
setModels((prev) => prev.filter((m) => m.id !== id));
|
||||
};
|
||||
|
||||
const vendorOptions = Array.from(new Set(models.map((m) => m.vendor).filter(Boolean)));
|
||||
|
||||
return (
|
||||
<div className="space-y-6 animate-in fade-in py-4 pb-10">
|
||||
<div className="flex items-center justify-between">
|
||||
@@ -162,10 +160,7 @@ export const ASRLibraryPage: React.FC = () => {
|
||||
value={vendorFilter}
|
||||
onChange={(e) => setVendorFilter(e.target.value)}
|
||||
>
|
||||
<option value="all">所有厂商</option>
|
||||
{vendorOptions.map((vendor) => (
|
||||
<option key={vendor} value={vendor}>{vendor}</option>
|
||||
))}
|
||||
<option value="OpenAI Compatible">OpenAI Compatible</option>
|
||||
</select>
|
||||
</div>
|
||||
<div className="flex items-center space-x-2">
|
||||
@@ -371,7 +366,6 @@ const ASRModelModal: React.FC<{
|
||||
onChange={(e) => setVendor(e.target.value)}
|
||||
>
|
||||
<option value="OpenAI Compatible">OpenAI Compatible</option>
|
||||
<option value="SiliconFlow">SiliconFlow</option>
|
||||
</select>
|
||||
</div>
|
||||
<div className="space-y-1.5">
|
||||
|
||||
@@ -5,32 +5,37 @@ import { Button, Input, Badge, Drawer, Dialog } from '../components/UI';
|
||||
import { ASRModel, Assistant, KnowledgeBase, LLMModel, TabValue, Tool, Voice } from '../types';
|
||||
import { createAssistant, deleteAssistant, fetchASRModels, fetchAssistants, fetchKnowledgeBases, fetchLLMModels, fetchTools, fetchVoices, updateAssistant as updateAssistantApi } from '../services/backendApi';
|
||||
|
||||
const isSiliconflowVendor = (vendor?: string) => {
|
||||
const isOpenAICompatibleVendor = (vendor?: string) => {
|
||||
const normalized = String(vendor || '').trim().toLowerCase();
|
||||
return normalized === 'siliconflow' || normalized === '硅基流动';
|
||||
return (
|
||||
normalized === 'siliconflow' ||
|
||||
normalized === '硅基流动' ||
|
||||
normalized === 'openai compatible' ||
|
||||
normalized === 'openai-compatible'
|
||||
);
|
||||
};
|
||||
|
||||
const SILICONFLOW_DEFAULT_MODEL = 'FunAudioLLM/CosyVoice2-0.5B';
|
||||
const OPENAI_COMPATIBLE_DEFAULT_MODEL = 'FunAudioLLM/CosyVoice2-0.5B';
|
||||
|
||||
const buildSiliconflowVoiceKey = (voiceId: string, model?: string) => {
|
||||
const buildOpenAICompatibleVoiceKey = (voiceId: string, model?: string) => {
|
||||
const id = String(voiceId || '').trim();
|
||||
if (!id) return '';
|
||||
if (id.includes(':')) return id;
|
||||
return `${model || SILICONFLOW_DEFAULT_MODEL}:${id}`;
|
||||
return `${model || OPENAI_COMPATIBLE_DEFAULT_MODEL}:${id}`;
|
||||
};
|
||||
|
||||
const resolveRuntimeTtsVoice = (selectedVoiceId: string, voice: Voice) => {
|
||||
const explicitKey = String(voice.voiceKey || '').trim();
|
||||
if (!isSiliconflowVendor(voice.vendor)) {
|
||||
if (!isOpenAICompatibleVendor(voice.vendor)) {
|
||||
return explicitKey || selectedVoiceId;
|
||||
}
|
||||
if (voice.isSystem) {
|
||||
const canonical = buildSiliconflowVoiceKey(selectedVoiceId, voice.model);
|
||||
const canonical = buildOpenAICompatibleVoiceKey(selectedVoiceId, voice.model);
|
||||
if (!explicitKey) return canonical;
|
||||
const explicitSuffix = explicitKey.includes(':') ? explicitKey.split(':').pop() : explicitKey;
|
||||
if (explicitSuffix && explicitSuffix !== selectedVoiceId) return canonical;
|
||||
}
|
||||
return explicitKey || buildSiliconflowVoiceKey(selectedVoiceId, voice.model);
|
||||
return explicitKey || buildOpenAICompatibleVoiceKey(selectedVoiceId, voice.model);
|
||||
};
|
||||
|
||||
const renderToolIcon = (icon: string) => {
|
||||
@@ -1830,11 +1835,11 @@ export const DebugDrawer: React.FC<{
|
||||
if (assistant.asrModelId) {
|
||||
const asr = asrModels.find((item) => item.id === assistant.asrModelId);
|
||||
if (asr) {
|
||||
const asrProvider = isSiliconflowVendor(asr.vendor) ? 'siliconflow' : 'buffered';
|
||||
const asrProvider = isOpenAICompatibleVendor(asr.vendor) ? 'openai_compatible' : 'buffered';
|
||||
services.asr = {
|
||||
provider: asrProvider,
|
||||
model: asr.modelName || asr.name,
|
||||
apiKey: asrProvider === 'siliconflow' ? asr.apiKey : null,
|
||||
apiKey: asrProvider === 'openai_compatible' ? asr.apiKey : null,
|
||||
};
|
||||
} else {
|
||||
warnings.push(`ASR model not found in loaded list: ${assistant.asrModelId}`);
|
||||
@@ -1844,12 +1849,12 @@ export const DebugDrawer: React.FC<{
|
||||
if (assistant.voice) {
|
||||
const voice = voices.find((item) => item.id === assistant.voice);
|
||||
if (voice) {
|
||||
const ttsProvider = isSiliconflowVendor(voice.vendor) ? 'siliconflow' : 'edge';
|
||||
const ttsProvider = isOpenAICompatibleVendor(voice.vendor) ? 'openai_compatible' : 'edge';
|
||||
services.tts = {
|
||||
enabled: ttsEnabled,
|
||||
provider: ttsProvider,
|
||||
model: voice.model,
|
||||
apiKey: ttsProvider === 'siliconflow' ? voice.apiKey : null,
|
||||
apiKey: ttsProvider === 'openai_compatible' ? voice.apiKey : null,
|
||||
voice: resolveRuntimeTtsVoice(assistant.voice, voice),
|
||||
speed: assistant.speed || voice.speed || 1.0,
|
||||
};
|
||||
|
||||
@@ -13,7 +13,7 @@ const maskApiKey = (key?: string) => {
|
||||
export const LLMLibraryPage: React.FC = () => {
|
||||
const [models, setModels] = useState<LLMModel[]>([]);
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const [vendorFilter, setVendorFilter] = useState<string>('all');
|
||||
const [vendorFilter, setVendorFilter] = useState<string>('OpenAI Compatible');
|
||||
const [typeFilter, setTypeFilter] = useState<string>('all');
|
||||
const [isAddModalOpen, setIsAddModalOpen] = useState(false);
|
||||
const [editingModel, setEditingModel] = useState<LLMModel | null>(null);
|
||||
@@ -41,7 +41,7 @@ export const LLMLibraryPage: React.FC = () => {
|
||||
m.name.toLowerCase().includes(q) ||
|
||||
(m.modelName || '').toLowerCase().includes(q) ||
|
||||
(m.baseUrl || '').toLowerCase().includes(q);
|
||||
const matchesVendor = vendorFilter === 'all' || m.vendor === vendorFilter;
|
||||
const matchesVendor = m.vendor === vendorFilter;
|
||||
const matchesType = typeFilter === 'all' || m.type === typeFilter;
|
||||
return matchesSearch && matchesVendor && matchesType;
|
||||
});
|
||||
@@ -64,8 +64,6 @@ export const LLMLibraryPage: React.FC = () => {
|
||||
setModels((prev) => prev.filter((item) => item.id !== id));
|
||||
};
|
||||
|
||||
const vendorOptions = Array.from(new Set(models.map((m) => m.vendor).filter(Boolean)));
|
||||
|
||||
return (
|
||||
<div className="space-y-6 animate-in fade-in py-4 pb-10">
|
||||
<div className="flex items-center justify-between">
|
||||
@@ -92,10 +90,7 @@ export const LLMLibraryPage: React.FC = () => {
|
||||
value={vendorFilter}
|
||||
onChange={(e) => setVendorFilter(e.target.value)}
|
||||
>
|
||||
<option value="all">所有厂商</option>
|
||||
{vendorOptions.map((vendor) => (
|
||||
<option key={vendor} value={vendor}>{vendor}</option>
|
||||
))}
|
||||
<option value="OpenAI Compatible">OpenAI Compatible</option>
|
||||
</select>
|
||||
</div>
|
||||
<div className="flex items-center space-x-2">
|
||||
@@ -284,8 +279,6 @@ const LLMModelModal: React.FC<{
|
||||
onChange={(e) => setVendor(e.target.value)}
|
||||
>
|
||||
<option value="OpenAI Compatible">OpenAI Compatible</option>
|
||||
<option value="OpenAI">OpenAI</option>
|
||||
<option value="SiliconFlow">SiliconFlow</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import React, { useEffect, useState, useRef } from 'react';
|
||||
import { Search, Mic2, Play, Pause, Upload, Filter, Plus, Volume2, Sparkles, ChevronDown, Pencil, Trash2 } from 'lucide-react';
|
||||
import { Search, Mic2, Play, Pause, Upload, Filter, Plus, Volume2, Pencil, Trash2 } from 'lucide-react';
|
||||
import { Button, Input, TableHeader, TableRow, TableHead, TableCell, Dialog, Badge } from '../components/UI';
|
||||
import { Voice } from '../types';
|
||||
import { createVoice, deleteVoice, fetchVoices, previewVoice, updateVoice } from '../services/backendApi';
|
||||
|
||||
const SILICONFLOW_DEFAULT_MODEL = 'FunAudioLLM/CosyVoice2-0.5B';
|
||||
const OPENAI_COMPATIBLE_DEFAULT_MODEL = 'FunAudioLLM/CosyVoice2-0.5B';
|
||||
|
||||
const buildSiliconflowVoiceKey = (rawId: string, model: string): string => {
|
||||
const buildOpenAICompatibleVoiceKey = (rawId: string, model: string): string => {
|
||||
const id = (rawId || '').trim();
|
||||
if (!id) return `${model}:anna`;
|
||||
return id.includes(':') ? id : `${model}:${id}`;
|
||||
@@ -15,7 +15,7 @@ const buildSiliconflowVoiceKey = (rawId: string, model: string): string => {
|
||||
export const VoiceLibraryPage: React.FC = () => {
|
||||
const [voices, setVoices] = useState<Voice[]>([]);
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const [vendorFilter, setVendorFilter] = useState<'all' | 'Ali' | 'Volcano' | 'Minimax' | '硅基流动' | 'SiliconFlow'>('all');
|
||||
const [vendorFilter, setVendorFilter] = useState<'OpenAI Compatible'>('OpenAI Compatible');
|
||||
const [genderFilter, setGenderFilter] = useState<'all' | 'Male' | 'Female'>('all');
|
||||
const [langFilter, setLangFilter] = useState<'all' | 'zh' | 'en'>('all');
|
||||
|
||||
@@ -44,7 +44,7 @@ export const VoiceLibraryPage: React.FC = () => {
|
||||
|
||||
const filteredVoices = voices.filter((voice) => {
|
||||
const matchesSearch = voice.name.toLowerCase().includes(searchTerm.toLowerCase());
|
||||
const matchesVendor = vendorFilter === 'all' || voice.vendor === vendorFilter;
|
||||
const matchesVendor = voice.vendor === vendorFilter;
|
||||
const matchesGender = genderFilter === 'all' || voice.gender === genderFilter;
|
||||
const matchesLang = langFilter === 'all' || voice.language === langFilter;
|
||||
return matchesSearch && matchesVendor && matchesGender && matchesLang;
|
||||
@@ -138,12 +138,7 @@ export const VoiceLibraryPage: React.FC = () => {
|
||||
value={vendorFilter}
|
||||
onChange={(e) => setVendorFilter(e.target.value as any)}
|
||||
>
|
||||
<option value="all">所有厂商</option>
|
||||
<option value="硅基流动">硅基流动 (SiliconFlow)</option>
|
||||
<option value="SiliconFlow">SiliconFlow</option>
|
||||
<option value="Ali">阿里 (Ali)</option>
|
||||
<option value="Volcano">火山 (Volcano)</option>
|
||||
<option value="Minimax">Minimax</option>
|
||||
<option value="OpenAI Compatible">OpenAI Compatible</option>
|
||||
</select>
|
||||
</div>
|
||||
<div className="flex items-center space-x-2">
|
||||
@@ -187,15 +182,12 @@ export const VoiceLibraryPage: React.FC = () => {
|
||||
<TableRow key={voice.id}>
|
||||
<TableCell className="font-medium">
|
||||
<div className="flex flex-col">
|
||||
<span className="flex items-center text-white">
|
||||
{voice.vendor === '硅基流动' && <Sparkles className="w-3 h-3 text-primary mr-1.5" />}
|
||||
{voice.name}
|
||||
</span>
|
||||
<span className="flex items-center text-white">{voice.name}</span>
|
||||
{voice.description && <span className="text-xs text-muted-foreground">{voice.description}</span>}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Badge variant={voice.vendor === '硅基流动' ? 'default' : 'outline'}>{voice.vendor}</Badge>
|
||||
<Badge variant="outline">{voice.vendor}</Badge>
|
||||
</TableCell>
|
||||
<TableCell className="text-muted-foreground">{voice.gender === 'Male' ? '男' : '女'}</TableCell>
|
||||
<TableCell className="text-muted-foreground">{voice.language === 'zh' ? '中文' : 'English'}</TableCell>
|
||||
@@ -254,17 +246,15 @@ const AddVoiceModal: React.FC<{
|
||||
onSuccess: (voice: Voice) => Promise<void>;
|
||||
initialVoice?: Voice;
|
||||
}> = ({ isOpen, onClose, onSuccess, initialVoice }) => {
|
||||
const [vendor, setVendor] = useState<'硅基流动' | 'Ali' | 'Volcano' | 'Minimax'>('硅基流动');
|
||||
const [vendor, setVendor] = useState<'OpenAI Compatible'>('OpenAI Compatible');
|
||||
const [name, setName] = useState('');
|
||||
|
||||
const [sfModel, setSfModel] = useState(SILICONFLOW_DEFAULT_MODEL);
|
||||
const [openaiCompatibleModel, setOpenaiCompatibleModel] = useState(OPENAI_COMPATIBLE_DEFAULT_MODEL);
|
||||
const [sfVoiceId, setSfVoiceId] = useState('FunAudioLLM/CosyVoice2-0.5B:anna');
|
||||
const [sfSpeed, setSfSpeed] = useState(1);
|
||||
const [sfGain, setSfGain] = useState(0);
|
||||
const [sfPitch, setSfPitch] = useState(0);
|
||||
|
||||
const [model, setModel] = useState('');
|
||||
const [voiceKey, setVoiceKey] = useState('');
|
||||
const [gender, setGender] = useState('Female');
|
||||
const [language, setLanguage] = useState('zh');
|
||||
const [description, setDescription] = useState('');
|
||||
@@ -278,17 +268,15 @@ const AddVoiceModal: React.FC<{
|
||||
|
||||
useEffect(() => {
|
||||
if (!initialVoice) return;
|
||||
const nextVendor = initialVoice.vendor === 'SiliconFlow' ? '硅基流动' : initialVoice.vendor;
|
||||
const nextModel = initialVoice.model || SILICONFLOW_DEFAULT_MODEL;
|
||||
const defaultVoiceKey = buildSiliconflowVoiceKey(initialVoice.id || initialVoice.name || '', nextModel);
|
||||
setVendor((nextVendor as any) || '硅基流动');
|
||||
const nextVendor = 'OpenAI Compatible';
|
||||
const nextModel = initialVoice.model || OPENAI_COMPATIBLE_DEFAULT_MODEL;
|
||||
const defaultVoiceKey = buildOpenAICompatibleVoiceKey(initialVoice.id || initialVoice.name || '', nextModel);
|
||||
setVendor(nextVendor);
|
||||
setName(initialVoice.name || '');
|
||||
setGender(initialVoice.gender || 'Female');
|
||||
setLanguage(initialVoice.language || 'zh');
|
||||
setDescription(initialVoice.description || '');
|
||||
setModel(initialVoice.model || '');
|
||||
setVoiceKey(initialVoice.voiceKey || '');
|
||||
setSfModel(nextModel);
|
||||
setOpenaiCompatibleModel(nextModel);
|
||||
setSfVoiceId((initialVoice.voiceKey || '').trim() || defaultVoiceKey);
|
||||
setSfSpeed(initialVoice.speed ?? 1);
|
||||
setSfGain(initialVoice.gain ?? 0);
|
||||
@@ -325,21 +313,21 @@ const AddVoiceModal: React.FC<{
|
||||
return;
|
||||
}
|
||||
|
||||
const resolvedSiliconflowVoiceKey = (() => {
|
||||
const resolvedVoiceKey = (() => {
|
||||
const current = (sfVoiceId || '').trim();
|
||||
if (current) return current;
|
||||
return buildSiliconflowVoiceKey(initialVoice?.id || name, sfModel || SILICONFLOW_DEFAULT_MODEL);
|
||||
return buildOpenAICompatibleVoiceKey(initialVoice?.id || name, openaiCompatibleModel || OPENAI_COMPATIBLE_DEFAULT_MODEL);
|
||||
})();
|
||||
|
||||
const newVoice: Voice = {
|
||||
id: initialVoice?.id || `${vendor === '硅基流动' ? 'sf' : 'gen'}-${Date.now()}`,
|
||||
id: initialVoice?.id || `oa-${Date.now()}`,
|
||||
name,
|
||||
vendor,
|
||||
gender,
|
||||
language,
|
||||
description: description || (vendor === '硅基流动' ? `Model: ${sfModel}` : `Model: ${model}`),
|
||||
model: vendor === '硅基流动' ? sfModel : model,
|
||||
voiceKey: vendor === '硅基流动' ? resolvedSiliconflowVoiceKey : voiceKey,
|
||||
description: description || `Model: ${openaiCompatibleModel}`,
|
||||
model: openaiCompatibleModel,
|
||||
voiceKey: resolvedVoiceKey,
|
||||
apiKey,
|
||||
baseUrl,
|
||||
speed: sfSpeed,
|
||||
@@ -351,10 +339,8 @@ const AddVoiceModal: React.FC<{
|
||||
setIsSaving(true);
|
||||
await onSuccess(newVoice);
|
||||
setName('');
|
||||
setVendor('硅基流动');
|
||||
setVendor('OpenAI Compatible');
|
||||
setDescription('');
|
||||
setModel('');
|
||||
setVoiceKey('');
|
||||
setApiKey('');
|
||||
setBaseUrl('');
|
||||
} catch (error: any) {
|
||||
@@ -381,19 +367,7 @@ const AddVoiceModal: React.FC<{
|
||||
<div className="space-y-4 max-h-[75vh] overflow-y-auto px-1 custom-scrollbar">
|
||||
<div className="space-y-1.5">
|
||||
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block">厂商 (Vendor)</label>
|
||||
<div className="relative">
|
||||
<select
|
||||
className="flex h-10 w-full rounded-md border border-white/10 bg-white/5 px-3 py-1 text-sm shadow-sm transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-primary/50 text-foreground appearance-none cursor-pointer [&>option]:bg-card"
|
||||
value={vendor}
|
||||
onChange={(e) => setVendor(e.target.value as any)}
|
||||
>
|
||||
<option value="硅基流动">硅基流动 (SiliconFlow)</option>
|
||||
<option value="Ali">阿里 (Ali)</option>
|
||||
<option value="Volcano">火山 (Volcano)</option>
|
||||
<option value="Minimax">Minimax</option>
|
||||
</select>
|
||||
<ChevronDown className="absolute right-3 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground pointer-events-none" />
|
||||
</div>
|
||||
<Input value={vendor} readOnly className="h-10 border border-white/10 bg-white/5" />
|
||||
</div>
|
||||
|
||||
<div className="h-px bg-white/5"></div>
|
||||
@@ -403,15 +377,14 @@ const AddVoiceModal: React.FC<{
|
||||
<Input value={name} onChange={(e) => setName(e.target.value)} placeholder="例如: 客服小美" />
|
||||
</div>
|
||||
|
||||
{vendor === '硅基流动' ? (
|
||||
<div className="space-y-4 animate-in fade-in slide-in-from-top-1 duration-200">
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div className="space-y-1.5">
|
||||
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block">模型 (Model)</label>
|
||||
<Input
|
||||
className="font-mono text-xs"
|
||||
value={sfModel}
|
||||
onChange={(e) => setSfModel(e.target.value)}
|
||||
value={openaiCompatibleModel}
|
||||
onChange={(e) => setOpenaiCompatibleModel(e.target.value)}
|
||||
placeholder="例如: FunAudioLLM/CosyVoice2-0.5B"
|
||||
/>
|
||||
</div>
|
||||
@@ -445,20 +418,6 @@ const AddVoiceModal: React.FC<{
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-4 animate-in fade-in slide-in-from-top-1 duration-200">
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div className="space-y-1.5">
|
||||
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block">模型标识</label>
|
||||
<Input value={model} onChange={(e) => setModel(e.target.value)} placeholder="API Model Key" />
|
||||
</div>
|
||||
<div className="space-y-1.5">
|
||||
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block">发音人标识</label>
|
||||
<Input value={voiceKey} onChange={(e) => setVoiceKey(e.target.value)} placeholder="Voice Key" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div className="space-y-1.5">
|
||||
@@ -560,7 +519,7 @@ const CloneVoiceModal: React.FC<{
|
||||
const newVoice: Voice = {
|
||||
id: `v-${Date.now()}`,
|
||||
name,
|
||||
vendor: 'Volcano',
|
||||
vendor: 'OpenAI Compatible',
|
||||
gender: 'Female',
|
||||
language: 'zh',
|
||||
description: description || 'User cloned voice',
|
||||
|
||||
@@ -55,8 +55,11 @@ const mapVoice = (raw: AnyRecord): Voice => ({
|
||||
id: String(readField(raw, ['id'], '')),
|
||||
name: readField(raw, ['name'], ''),
|
||||
vendor: ((): string => {
|
||||
const vendor = String(readField(raw, ['vendor'], ''));
|
||||
return vendor.toLowerCase() === 'siliconflow' ? '硅基流动' : vendor;
|
||||
const vendor = String(readField(raw, ['vendor'], '')).trim().toLowerCase();
|
||||
if (vendor === 'siliconflow' || vendor === '硅基流动' || vendor === 'openai-compatible') {
|
||||
return 'OpenAI Compatible';
|
||||
}
|
||||
return String(readField(raw, ['vendor'], 'OpenAI Compatible')) || 'OpenAI Compatible';
|
||||
})(),
|
||||
gender: readField(raw, ['gender'], ''),
|
||||
language: readField(raw, ['language'], ''),
|
||||
@@ -296,7 +299,7 @@ export const createVoice = async (data: Partial<Voice>): Promise<Voice> => {
|
||||
const payload = {
|
||||
id: data.id || undefined,
|
||||
name: data.name || 'New Voice',
|
||||
vendor: data.vendor === '硅基流动' ? 'SiliconFlow' : (data.vendor || 'SiliconFlow'),
|
||||
vendor: data.vendor || 'OpenAI Compatible',
|
||||
gender: data.gender || 'Female',
|
||||
language: data.language || 'zh',
|
||||
description: data.description || '',
|
||||
@@ -316,7 +319,7 @@ export const createVoice = async (data: Partial<Voice>): Promise<Voice> => {
|
||||
export const updateVoice = async (id: string, data: Partial<Voice>): Promise<Voice> => {
|
||||
const payload = {
|
||||
name: data.name,
|
||||
vendor: data.vendor === '硅基流动' ? 'SiliconFlow' : data.vendor,
|
||||
vendor: data.vendor,
|
||||
gender: data.gender,
|
||||
language: data.language,
|
||||
description: data.description,
|
||||
|
||||
@@ -200,7 +200,7 @@ export const mockLLMModels: LLMModel[] = [
|
||||
{ id: 'm1', name: 'GPT-4o', vendor: 'OpenAI Compatible', type: 'text', baseUrl: 'https://api.openai.com/v1', apiKey: 'sk-***', temperature: 0.7 },
|
||||
{ id: 'm2', name: 'DeepSeek-V3', vendor: 'OpenAI Compatible', type: 'text', baseUrl: 'https://api.deepseek.com', apiKey: 'sk-***', temperature: 0.5 },
|
||||
{ id: 'm3', name: 'text-embedding-3-small', vendor: 'OpenAI Compatible', type: 'embedding', baseUrl: 'https://api.openai.com/v1', apiKey: 'sk-***' },
|
||||
{ id: 'm4', name: 'bge-reranker-v2-m3', vendor: 'SiliconFlow', type: 'rerank', baseUrl: 'https://api.siliconflow.cn/v1', apiKey: 'sk-***' },
|
||||
{ id: 'm4', name: 'bge-reranker-v2-m3', vendor: 'OpenAI Compatible', type: 'rerank', baseUrl: 'https://api.siliconflow.cn/v1', apiKey: 'sk-***' },
|
||||
];
|
||||
|
||||
export const mockASRModels: ASRModel[] = [
|
||||
|
||||
Reference in New Issue
Block a user