Update backend api
This commit is contained in:
@@ -4,7 +4,7 @@ from contextlib import asynccontextmanager
|
||||
import os
|
||||
|
||||
from .db import Base, engine
|
||||
from .routers import assistants, history
|
||||
from .routers import assistants, history, knowledge, llm, asr, tools
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -33,6 +33,10 @@ app.add_middleware(
|
||||
# Routes: every feature router is mounted under the shared /api prefix
app.include_router(assistants.router, prefix="/api")
app.include_router(history.router, prefix="/api")
app.include_router(knowledge.router, prefix="/api")
app.include_router(llm.router, prefix="/api")
app.include_router(asr.router, prefix="/api")
app.include_router(tools.router, prefix="/api")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
|
||||
@@ -3,9 +3,15 @@ from fastapi import APIRouter
|
||||
from . import assistants
|
||||
from . import history
|
||||
from . import knowledge
|
||||
from . import llm
|
||||
from . import asr
|
||||
from . import tools
|
||||
|
||||
router = APIRouter()

# Aggregate every feature router into one top-level router so the
# application can mount the whole API surface in a single call.
router.include_router(assistants.router)
router.include_router(history.router)
router.include_router(knowledge.router)
router.include_router(llm.router)
router.include_router(asr.router)
router.include_router(tools.router)
|
||||
|
||||
268
api/app/routers/asr.py
Normal file
268
api/app/routers/asr.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
import httpx
|
||||
import time
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import ASRModel
|
||||
from ..schemas import (
|
||||
ASRModelCreate, ASRModelUpdate, ASRModelOut,
|
||||
ASRTestRequest, ASRTestResponse, ListResponse
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/asr", tags=["ASR Models"])
|
||||
|
||||
|
||||
# ============ ASR Models CRUD ============
|
||||
@router.get("", response_model=ListResponse)
def list_asr_models(
    language: Optional[str] = None,
    enabled: Optional[bool] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """Return a paginated list of ASR models, optionally filtered by language/enabled."""
    query = db.query(ASRModel)

    # Apply optional filters; `enabled` may legitimately be False, so test against None.
    if language:
        query = query.filter(ASRModel.language == language)
    if enabled is not None:
        query = query.filter(ASRModel.enabled == enabled)

    total = query.count()
    offset = (page - 1) * limit
    rows = (
        query.order_by(ASRModel.created_at.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )

    return {"total": total, "page": page, "limit": limit, "list": rows}
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=ASRModelOut)
def get_asr_model(id: str, db: Session = Depends(get_db)):
    """Fetch a single ASR model by id; responds 404 when it does not exist."""
    record = db.query(ASRModel).filter(ASRModel.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    return record
|
||||
|
||||
|
||||
@router.post("", response_model=ASRModelOut)
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
    """Create an ASR model record and return the persisted row."""
    # Client may supply an id; otherwise mint a short 8-char UUID prefix.
    new_id = data.id or str(uuid.uuid4())[:8]
    record = ASRModel(
        id=new_id,
        user_id=1,  # default user (no auth wired up yet)
        name=data.name,
        vendor=data.vendor,
        language=data.language,
        base_url=data.base_url,
        api_key=data.api_key,
        model_name=data.model_name,
        hotwords=data.hotwords,
        enable_punctuation=data.enable_punctuation,
        enable_normalization=data.enable_normalization,
        enabled=data.enabled,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
||||
|
||||
|
||||
@router.put("/{id}", response_model=ASRModelOut)
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to an ASR model; responds 404 when it is missing."""
    record = db.query(ASRModel).filter(ASRModel.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    # Only fields the client explicitly provided are written back.
    changes = data.model_dump(exclude_unset=True)
    for field_name, new_value in changes.items():
        setattr(record, field_name, new_value)

    db.commit()
    db.refresh(record)
    return record
|
||||
|
||||
|
||||
@router.delete("/{id}")
def delete_asr_model(id: str, db: Session = Depends(get_db)):
    """Delete an ASR model; responds 404 when the id is unknown."""
    record = db.query(ASRModel).filter(ASRModel.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    db.delete(record)
    db.commit()
    return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/{id}/test", response_model=ASRTestResponse)
def test_asr_model(
    id: str,
    request: Optional[ASRTestRequest] = None,
    db: Session = Depends(get_db)
):
    """Test an ASR model by issuing a short transcription request.

    Supports SiliconFlow/Paraformer- and OpenAI-Whisper-style endpoints.
    Returns an ASRTestResponse carrying the transcript and latency on
    success, or an error/message describing the failure. Raises 404 when
    the model id is unknown.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    start_time = time.time()

    try:
        vendor = model.vendor.lower()
        if vendor in ["siliconflow", "paraformer"]:
            # SiliconFlow / Paraformer request format
            payload = {
                "model": model.model_name or "paraformer-v2",
                "input": {},
                "parameters": {
                    "hotwords": " ".join(model.hotwords) if model.hotwords else "",
                    "enable_punctuation": model.enable_punctuation,
                    "enable_normalization": model.enable_normalization,
                }
            }

            # Attach audio input, if any was supplied
            if request and request.audio_data:
                payload["input"]["file_urls"] = []
            elif request and request.audio_url:
                payload["input"]["url"] = request.audio_url

            headers = {"Authorization": f"Bearer {model.api_key}"}

            with httpx.Client(timeout=60.0) as client:
                response = client.post(
                    f"{model.base_url}/asr",
                    json=payload,
                    headers=headers
                )
                response.raise_for_status()
                result = response.json()

        elif vendor == "openai":
            # OpenAI Whisper request format
            headers = {"Authorization": f"Bearer {model.api_key}"}

            # Prepare the multipart audio file
            if request and request.audio_data:
                audio_bytes = base64.b64decode(request.audio_data)
                files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
                data = {"model": model.model_name or "whisper-1"}
            elif request and request.audio_url:
                files = {"file": ("audio.wav", httpx.get(request.audio_url).content, "audio/wav")}
                data = {"model": model.model_name or "whisper-1"}
            else:
                return ASRTestResponse(
                    success=False,
                    error="No audio data or URL provided"
                )

            with httpx.Client(timeout=60.0) as client:
                response = client.post(
                    f"{model.base_url}/audio/transcriptions",
                    files=files,
                    data=data,
                    headers=headers
                )
                response.raise_for_status()
                result = response.json()
                # Normalize to the Paraformer-style result shape
                result = {"results": [{"transcript": result.get("text", "")}]}

        else:
            # Generic vendors are not implemented yet
            return ASRTestResponse(
                success=False,
                message=f"Unsupported vendor: {model.vendor}"
            )

        latency_ms = int((time.time() - start_time) * 1000)

        # Parse the results list safely. The original
        # `result.get("results", [{}])[0]` raised IndexError whenever the
        # upstream API returned an explicit empty "results" list.
        results = result.get("results") or [{}]
        result_data = results[0]
        if result_data:
            return ASRTestResponse(
                success=True,
                transcript=result_data.get("transcript", ""),
                language=result_data.get("language", model.language),
                confidence=result_data.get("confidence"),
                latency_ms=latency_ms,
            )

        return ASRTestResponse(
            success=False,
            message="No transcript in response",
            latency_ms=latency_ms
        )

    except httpx.HTTPStatusError as e:
        return ASRTestResponse(
            success=False,
            error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
        )
    except Exception as e:
        return ASRTestResponse(
            success=False,
            error=str(e)[:200]
        )
|
||||
|
||||
|
||||
@router.post("/{id}/transcribe")
def transcribe_audio(
    id: str,
    audio_url: Optional[str] = None,
    audio_data: Optional[str] = None,
    hotwords: Optional[List[str]] = None,
    db: Session = Depends(get_db)
):
    """Transcribe audio through a Paraformer-style ASR endpoint.

    `audio_url` takes precedence over `audio_data`; request-level hotwords
    override the model's configured hotwords. Raises 404 for an unknown
    model id and 500 when the upstream call fails.
    """
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")

    try:
        payload = {
            "model": model.model_name or "paraformer-v2",
            "input": {},
            "parameters": {
                "hotwords": " ".join(hotwords or model.hotwords or []),
                "enable_punctuation": model.enable_punctuation,
                "enable_normalization": model.enable_normalization,
            }
        }

        headers = {"Authorization": f"Bearer {model.api_key}"}

        if audio_url:
            payload["input"]["url"] = audio_url
        elif audio_data:
            payload["input"]["file_urls"] = []

        with httpx.Client(timeout=120.0) as client:
            response = client.post(
                f"{model.base_url}/asr",
                json=payload,
                headers=headers
            )
            response.raise_for_status()

        result = response.json()

        # Guard against an explicit empty "results" list: the original
        # `result.get("results", [{}])[0]` raised IndexError in that case.
        results = result.get("results") or [{}]
        result_data = results[0]
        if result_data:
            return {
                "success": True,
                "transcript": result_data.get("transcript", ""),
                "language": result_data.get("language", model.language),
                "confidence": result_data.get("confidence"),
            }

        return {"success": False, "error": "No transcript in response"}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
@@ -8,7 +8,7 @@ from ..db import get_db
|
||||
from ..models import Assistant, Voice, Workflow
|
||||
from ..schemas import (
|
||||
AssistantCreate, AssistantUpdate, AssistantOut,
|
||||
VoiceOut,
|
||||
VoiceCreate, VoiceUpdate, VoiceOut,
|
||||
WorkflowCreate, WorkflowUpdate, WorkflowOut
|
||||
)
|
||||
|
||||
@@ -16,11 +16,88 @@ router = APIRouter()
|
||||
|
||||
|
||||
# ============ Voices ============
|
||||
@router.get("/voices", response_model=List[VoiceOut])
|
||||
def list_voices(db: Session = Depends(get_db)):
|
||||
@router.get("/voices")
|
||||
def list_voices(
|
||||
vendor: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
gender: Optional[str] = None,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取声音库列表"""
|
||||
voices = db.query(Voice).all()
|
||||
return voices
|
||||
query = db.query(Voice)
|
||||
if vendor:
|
||||
query = query.filter(Voice.vendor == vendor)
|
||||
if language:
|
||||
query = query.filter(Voice.language == language)
|
||||
if gender:
|
||||
query = query.filter(Voice.gender == gender)
|
||||
|
||||
total = query.count()
|
||||
voices = query.order_by(Voice.created_at.desc()) \
|
||||
.offset((page-1)*limit).limit(limit).all()
|
||||
return {"total": total, "page": page, "limit": limit, "list": voices}
|
||||
|
||||
|
||||
@router.post("/voices", response_model=VoiceOut)
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
    """Create a voice entry and return the persisted row."""
    record = Voice(
        id=data.id or str(uuid.uuid4())[:8],
        user_id=1,  # default user (no auth wired up yet)
        name=data.name,
        vendor=data.vendor,
        gender=data.gender,
        language=data.language,
        description=data.description,
        model=data.model,
        voice_key=data.voice_key,
        speed=data.speed,
        gain=data.gain,
        pitch=data.pitch,
        enabled=data.enabled,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
||||
|
||||
|
||||
@router.get("/voices/{id}", response_model=VoiceOut)
def get_voice(id: str, db: Session = Depends(get_db)):
    """Fetch a single voice by id; responds 404 when it does not exist."""
    record = db.query(Voice).filter(Voice.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Voice not found")
    return record
|
||||
|
||||
|
||||
@router.put("/voices/{id}", response_model=VoiceOut)
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to a voice; responds 404 when it is missing."""
    record = db.query(Voice).filter(Voice.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Voice not found")

    # Only client-supplied fields are written back.
    for field_name, new_value in data.model_dump(exclude_unset=True).items():
        setattr(record, field_name, new_value)

    db.commit()
    db.refresh(record)
    return record
|
||||
|
||||
|
||||
@router.delete("/voices/{id}")
def delete_voice(id: str, db: Session = Depends(get_db)):
    """Delete a voice; responds 404 when the id is unknown."""
    record = db.query(Voice).filter(Voice.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Voice not found")
    db.delete(record)
    db.commit()
    return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
# ============ Assistants ============
|
||||
@@ -103,10 +180,17 @@ def delete_assistant(id: str, db: Session = Depends(get_db)):
|
||||
|
||||
# ============ Workflows ============
|
||||
@router.get("/workflows", response_model=List[WorkflowOut])
|
||||
def list_workflows(db: Session = Depends(get_db)):
|
||||
def list_workflows(
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取工作流列表"""
|
||||
workflows = db.query(Workflow).all()
|
||||
return workflows
|
||||
query = db.query(Workflow)
|
||||
total = query.count()
|
||||
workflows = query.order_by(Workflow.created_at.desc()) \
|
||||
.offset((page-1)*limit).limit(limit).all()
|
||||
return {"total": total, "page": page, "limit": limit, "list": workflows}
|
||||
|
||||
|
||||
@router.post("/workflows", response_model=WorkflowOut)
|
||||
@@ -129,6 +213,15 @@ def create_workflow(data: WorkflowCreate, db: Session = Depends(get_db)):
|
||||
return workflow
|
||||
|
||||
|
||||
@router.get("/workflows/{id}", response_model=WorkflowOut)
def get_workflow(id: str, db: Session = Depends(get_db)):
    """Fetch a single workflow by id; responds 404 when it does not exist."""
    record = db.query(Workflow).filter(Workflow.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Workflow not found")
    return record
|
||||
|
||||
|
||||
@router.put("/workflows/{id}", response_model=WorkflowOut)
|
||||
def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)):
|
||||
"""更新工作流"""
|
||||
|
||||
206
api/app/routers/llm.py
Normal file
206
api/app/routers/llm.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
import httpx
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import LLMModel
|
||||
from ..schemas import (
|
||||
LLMModelCreate, LLMModelUpdate, LLMModelOut,
|
||||
LLMModelTestResponse, ListResponse
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/llm", tags=["LLM Models"])
|
||||
|
||||
|
||||
# ============ LLM Models CRUD ============
|
||||
@router.get("", response_model=ListResponse)
def list_llm_models(
    model_type: Optional[str] = None,
    enabled: Optional[bool] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """Return a paginated list of LLM models, optionally filtered by type/enabled."""
    query = db.query(LLMModel)

    if model_type:
        query = query.filter(LLMModel.type == model_type)
    # `enabled` may legitimately be False, so compare against None.
    if enabled is not None:
        query = query.filter(LLMModel.enabled == enabled)

    total = query.count()
    offset = (page - 1) * limit
    rows = (
        query.order_by(LLMModel.created_at.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )

    return {"total": total, "page": page, "limit": limit, "list": rows}
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=LLMModelOut)
def get_llm_model(id: str, db: Session = Depends(get_db)):
    """Fetch a single LLM model by id; responds 404 when it does not exist."""
    record = db.query(LLMModel).filter(LLMModel.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="LLM Model not found")
    return record
|
||||
|
||||
|
||||
@router.post("", response_model=LLMModelOut)
def create_llm_model(data: LLMModelCreate, db: Session = Depends(get_db)):
    """Create an LLM model record and return the persisted row."""
    # `type` may arrive as an Enum; store its raw value in that case.
    raw_type = data.type.value if hasattr(data.type, 'value') else data.type
    record = LLMModel(
        id=data.id or str(uuid.uuid4())[:8],
        user_id=1,  # default user (no auth wired up yet)
        name=data.name,
        vendor=data.vendor,
        type=raw_type,
        base_url=data.base_url,
        api_key=data.api_key,
        model_name=data.model_name,
        temperature=data.temperature,
        context_length=data.context_length,
        enabled=data.enabled,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
||||
|
||||
|
||||
@router.put("/{id}", response_model=LLMModelOut)
def update_llm_model(id: str, data: LLMModelUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to an LLM model; responds 404 when it is missing."""
    record = db.query(LLMModel).filter(LLMModel.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="LLM Model not found")

    # Only client-supplied fields are written back.
    for field_name, new_value in data.model_dump(exclude_unset=True).items():
        setattr(record, field_name, new_value)

    # Stamp the modification time explicitly (naive UTC, matching the schema).
    record.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(record)
    return record
|
||||
|
||||
|
||||
@router.delete("/{id}")
def delete_llm_model(id: str, db: Session = Depends(get_db)):
    """Delete an LLM model; responds 404 when the id is unknown."""
    record = db.query(LLMModel).filter(LLMModel.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="LLM Model not found")
    db.delete(record)
    db.commit()
    return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/{id}/test", response_model=LLMModelTestResponse)
def test_llm_model(id: str, db: Session = Depends(get_db)):
    """Probe the model's chat endpoint with a tiny request and report latency."""
    record = db.query(LLMModel).filter(LLMModel.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="LLM Model not found")

    started = time.time()
    try:
        # Minimal chat request: one short message, near-zero sampling.
        request_body = {
            "model": record.model_name or "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "Hello, please reply with 'OK'."}],
            "max_tokens": 10,
            "temperature": 0.1,
        }
        auth = {"Authorization": f"Bearer {record.api_key}"}

        with httpx.Client(timeout=30.0) as client:
            resp = client.post(
                f"{record.base_url}/chat/completions",
                json=request_body,
                headers=auth,
            )
            resp.raise_for_status()

        elapsed = int((time.time() - started) * 1000)
        body = resp.json()

        if body.get("choices"):
            return LLMModelTestResponse(
                success=True,
                latency_ms=elapsed,
                message="Connection successful"
            )
        return LLMModelTestResponse(
            success=False,
            latency_ms=elapsed,
            message="Unexpected response format"
        )

    except httpx.HTTPStatusError as e:
        return LLMModelTestResponse(
            success=False,
            message=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
        )
    except Exception as e:
        return LLMModelTestResponse(
            success=False,
            message=str(e)[:200]
        )
|
||||
|
||||
|
||||
@router.post("/{id}/chat")
def chat_with_llm(
    id: str,
    message: str,
    system_prompt: Optional[str] = None,
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    db: Session = Depends(get_db)
):
    """Send one chat message through the configured LLM model.

    Builds an OpenAI-compatible /chat/completions request; the optional
    system prompt, token limit, and temperature override the model's
    stored defaults. Raises 404 for an unknown model id and 500 when the
    upstream call fails.
    """
    model = db.query(LLMModel).filter(LLMModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="LLM Model not found")

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": message})

    payload = {
        "model": model.model_name or "gpt-3.5-turbo",
        "messages": messages,
        "max_tokens": max_tokens or 1000,
        "temperature": temperature if temperature is not None else model.temperature or 0.7,
    }

    headers = {"Authorization": f"Bearer {model.api_key}"}

    try:
        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                f"{model.base_url}/chat/completions",
                json=payload,
                headers=headers
            )
            response.raise_for_status()

        result = response.json()
        # Guard against an explicit empty "choices" list: the original
        # `result.get("choices", [{}])[0]` raised IndexError in that case.
        choices = result.get("choices") or [{}]
        choice = choices[0]
        if choice:
            return {
                "success": True,
                "reply": choice.get("message", {}).get("content", ""),
                "usage": result.get("usage", {})
            }
        return {"success": False, "reply": "", "error": "No response"}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
379
api/app/routers/tools.py
Normal file
379
api/app/routers/tools.py
Normal file
@@ -0,0 +1,379 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, Dict, Any
|
||||
import time
|
||||
import uuid
|
||||
import httpx
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import LLMModel, ASRModel
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["Tools & Autotest"])
|
||||
|
||||
|
||||
# ============ Available Tools ============
|
||||
# ============ Available Tools ============
# Static registry of built-in tools, keyed by tool id. Each entry carries a
# human-readable name/description (served verbatim to clients, hence kept in
# Chinese) plus a JSON-Schema "parameters" object in the OpenAI
# function-calling format.
TOOL_REGISTRY = {
    "search": {
        "name": "网络搜索",
        "description": "搜索互联网获取最新信息",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "搜索关键词"}
            },
            "required": ["query"]
        }
    },
    "calculator": {
        "name": "计算器",
        "description": "执行数学计算",
        "parameters": {
            "type": "object",
            "properties": {
                "expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}
            },
            "required": ["expression"]
        }
    },
    "weather": {
        "name": "天气查询",
        "description": "查询指定城市的天气",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "城市名称"}
            },
            "required": ["city"]
        }
    },
    "translate": {
        "name": "翻译",
        "description": "翻译文本到指定语言",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {"type": "string", "description": "要翻译的文本"},
                "target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"}
            },
            "required": ["text", "target_lang"]
        }
    },
    "knowledge": {
        "name": "知识库查询",
        "description": "从知识库中检索相关信息",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "查询内容"},
                "kb_id": {"type": "string", "description": "知识库ID"}
            },
            "required": ["query"]
        }
    },
    "code_interpreter": {
        "name": "代码执行",
        "description": "安全地执行Python代码",
        "parameters": {
            "type": "object",
            "properties": {
                "code": {"type": "string", "description": "要执行的Python代码"}
            },
            "required": ["code"]
        }
    },
}
|
||||
|
||||
|
||||
@router.get("/list")
def list_available_tools():
    """Return the static registry of tools the platform exposes."""
    return {"tools": TOOL_REGISTRY}
|
||||
|
||||
|
||||
@router.get("/list/{tool_id}")
def get_tool_detail(tool_id: str):
    """Return the registry entry for one tool; responds 404 for unknown ids."""
    try:
        return TOOL_REGISTRY[tool_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Tool not found")
|
||||
|
||||
|
||||
# ============ Autotest ============
|
||||
class AutotestResult:
    """Accumulates the outcome of one self-test run.

    Collects individual test entries plus a running pass/fail summary and
    serializes itself to a plain dict for the API response.
    """

    def __init__(self):
        # Short run id; wall-clock start lets to_dict() compute total duration.
        self.id = str(uuid.uuid4())[:8]
        self.started_at = time.time()
        self.tests = []
        self.summary = {"passed": 0, "failed": 0, "total": 0}

    def add_test(self, name: str, passed: bool, message: str = "", duration_ms: int = 0):
        """Record one test outcome and bump the summary counters."""
        entry = {
            "name": name,
            "passed": passed,
            "message": message,
            "duration_ms": duration_ms,
        }
        self.tests.append(entry)
        bucket = "passed" if passed else "failed"
        self.summary[bucket] += 1
        self.summary["total"] += 1

    def to_dict(self):
        """Serialize the run; total duration is computed at call time."""
        elapsed_ms = int((time.time() - self.started_at) * 1000)
        return {
            "id": self.id,
            "started_at": self.started_at,
            "duration_ms": elapsed_ms,
            "tests": self.tests,
            "summary": self.summary,
        }
|
||||
|
||||
|
||||
@router.post("/autotest")
def run_autotest(
    llm_model_id: Optional[str] = None,
    asr_model_id: Optional[str] = None,
    test_llm: bool = True,
    test_asr: bool = True,
    db: Session = Depends(get_db)
):
    """Run the self-test suite against the selected LLM and/or ASR models."""
    report = AutotestResult()

    # Run the real batteries first (only when an id was supplied)...
    if test_llm and llm_model_id:
        _test_llm_model(db, llm_model_id, report)
    if test_asr and asr_model_id:
        _test_asr_model(db, asr_model_id, report)

    # ...then record missing-id failures for any requested-but-unspecified model,
    # so those entries always appear at the end of the report.
    if test_llm and not llm_model_id:
        report.add_test("LLM Model Check", False, "No LLM model ID provided")
    if test_asr and not asr_model_id:
        report.add_test("ASR Model Check", False, "No ASR model ID provided")

    return report.to_dict()
|
||||
|
||||
|
||||
@router.post("/autotest/llm/{model_id}")
def autotest_llm_model(model_id: str, db: Session = Depends(get_db)):
    """Run the LLM test battery against a single model and return the report."""
    report = AutotestResult()
    _test_llm_model(db, model_id, report)
    return report.to_dict()
|
||||
|
||||
|
||||
@router.post("/autotest/asr/{model_id}")
def autotest_asr_model(model_id: str, db: Session = Depends(get_db)):
    """Run the ASR test battery against a single model and return the report."""
    report = AutotestResult()
    _test_asr_model(db, model_id, report)
    return report.to_dict()
|
||||
|
||||
|
||||
def _test_llm_model(db: Session, model_id: str, result: AutotestResult):
    """Internal helper: run the LLM-model test battery, appending to *result*.

    Checks, in order: the model row exists, the chat-completions endpoint
    answers, the temperature setting, and (for text models) streaming.
    """
    start_time = time.time()

    # 1. Check the model exists
    model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
    duration_ms = int((time.time() - start_time) * 1000)

    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return

    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # Build the auth header once, up front. The original assigned it inside
    # the connection-test try block and reused it in the streaming test
    # below, risking a NameError if that try failed before the assignment.
    headers = {"Authorization": f"Bearer {model.api_key}"}

    # 2. Test the API connection with a minimal chat request
    test_start = time.time()
    try:
        payload = {
            "model": model.model_name or "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "Reply with 'OK'."}],
            "max_tokens": 10,
            "temperature": 0.1,
        }

        with httpx.Client(timeout=30.0) as client:
            response = client.post(
                f"{model.base_url}/chat/completions",
                json=payload,
                headers=headers
            )
            response.raise_for_status()

        result_text = response.json()
        latency_ms = int((time.time() - test_start) * 1000)

        if result_text.get("choices"):
            result.add_test("API Connection", True, f"Latency: {latency_ms}ms", latency_ms)
        else:
            result.add_test("API Connection", False, "Empty response", latency_ms)

    except Exception as e:
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Connection", False, str(e)[:200], latency_ms)

    # 3. Check the model configuration (always recorded as a pass)
    if model.temperature is not None:
        result.add_test("Temperature Setting", True, f"temperature={model.temperature}")
    else:
        result.add_test("Temperature Setting", True, "Using default")

    # 4. Test streaming support (text models only)
    if model.type == "text":
        test_start = time.time()
        try:
            with httpx.Client(timeout=30.0) as client:
                with client.stream(
                    "POST",
                    f"{model.base_url}/chat/completions",
                    json={
                        "model": model.model_name or "gpt-3.5-turbo",
                        "messages": [{"role": "user", "content": "Count from 1 to 3."}],
                        "stream": True,
                    },
                    headers=headers
                ) as response:
                    response.raise_for_status()
                    chunk_count = 0
                    for _ in response.iter_bytes():
                        chunk_count += 1

            latency_ms = int((time.time() - test_start) * 1000)
            result.add_test("Streaming Support", True, f"Received {chunk_count} chunks", latency_ms)
        except Exception as e:
            latency_ms = int((time.time() - test_start) * 1000)
            result.add_test("Streaming Support", False, str(e)[:200], latency_ms)
|
||||
|
||||
|
||||
def _test_asr_model(db: Session, model_id: str, result: AutotestResult):
    """Internal helper: run the ASR-model test battery, appending to *result*.

    Checks, in order: the model row exists, the hotword configuration, API
    reachability (via a vendor-specific probe endpoint), and the language
    setting.
    """
    start_time = time.time()

    # 1. Check the model exists
    model = db.query(ASRModel).filter(ASRModel.id == model_id).first()
    duration_ms = int((time.time() - start_time) * 1000)

    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return

    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # 2. Check the configuration (hotwords are optional; both cases pass)
    if model.hotwords:
        result.add_test("Hotwords Config", True, f"Hotwords: {len(model.hotwords)} words")
    else:
        result.add_test("Hotwords Config", True, "No hotwords configured")

    # 3. Probe API availability with a vendor-appropriate GET request
    test_start = time.time()
    try:
        headers = {"Authorization": f"Bearer {model.api_key}"}

        with httpx.Client(timeout=30.0) as client:
            if model.vendor.lower() in ["siliconflow", "paraformer"]:
                response = client.get(
                    f"{model.base_url}/asr",
                    headers=headers
                )
            elif model.vendor.lower() == "openai":
                response = client.get(
                    f"{model.base_url}/audio/models",
                    headers=headers
                )
            else:
                # Generic health-check endpoint for unknown vendors
                response = client.get(
                    f"{model.base_url}/health",
                    headers=headers
                )

        latency_ms = int((time.time() - test_start) * 1000)

        if response.status_code in [200, 405]:  # 405 = method not allowed but endpoint exists
            result.add_test("API Availability", True, f"Status: {response.status_code}", latency_ms)
        else:
            result.add_test("API Availability", False, f"Status: {response.status_code}", latency_ms)

    except httpx.TimeoutException:
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Availability", False, "Connection timeout", latency_ms)
    except Exception as e:
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Availability", False, str(e)[:200], latency_ms)

    # 4. Check the language configuration against the known values
    if model.language in ["zh", "en", "Multi-lingual"]:
        result.add_test("Language Config", True, f"Language: {model.language}")
    else:
        result.add_test("Language Config", False, f"Unknown language: {model.language}")
|
||||
|
||||
|
||||
# ============ Quick Health Check ============
|
||||
@router.get("/health")
def health_check():
    """Lightweight liveness probe.

    Returns the service status, the current UNIX timestamp, and the IDs of
    all tools currently registered in TOOL_REGISTRY.
    """
    registered_tools = [tool_id for tool_id in TOOL_REGISTRY]
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "tools": registered_tools,
    }
|
||||
|
||||
|
||||
@router.post("/test-message")
def send_test_message(
    llm_model_id: str,
    message: str = "Hello, this is a test message.",
    db: Session = Depends(get_db)
):
    """发送测试消息.

    Send a one-off chat message to the configured LLM model (OpenAI-compatible
    /chat/completions endpoint) and return the model's reply plus token usage.

    Args:
        llm_model_id: ID of the LLMModel row to test.
        message: The user message to send (defaults to a canned test string).
        db: Database session (injected).

    Raises:
        HTTPException: 404 if the model does not exist; 500 if the upstream
            request fails (the upstream HTTP status code is included in the
            detail when available).
    """
    model = db.query(LLMModel).filter(LLMModel.id == llm_model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="LLM Model not found")

    payload = {
        "model": model.model_name or "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": message}],
        "max_tokens": 500,
        "temperature": 0.7,
    }
    headers = {"Authorization": f"Bearer {model.api_key}"}

    try:
        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                f"{model.base_url}/chat/completions",
                json=payload,
                headers=headers
            )
            response.raise_for_status()

        result = response.json()
        # OpenAI-compatible response shape; fall back to "" if choices is empty.
        reply = result.get("choices", [{}])[0].get("message", {}).get("content", "")

        return {
            "success": True,
            "reply": reply,
            "usage": result.get("usage", {})
        }

    except httpx.HTTPStatusError as e:
        # Surface the upstream status code instead of an opaque stringified error.
        raise HTTPException(
            status_code=500,
            detail=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}",
        ) from e
    except Exception as e:
        # Network errors, timeouts, malformed JSON, etc.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
409
api/docs/asr.md
Normal file
409
api/docs/asr.md
Normal file
@@ -0,0 +1,409 @@
|
||||
# 语音识别 (ASR Model) API
|
||||
|
||||
语音识别 API 用于管理语音识别模型的配置和调用。
|
||||
|
||||
## 基础信息
|
||||
|
||||
| 项目 | 值 |
|
||||
|------|-----|
|
||||
| Base URL | `/api/v1/asr` |
|
||||
| 认证方式 | Bearer Token (预留) |
|
||||
|
||||
---
|
||||
|
||||
## 数据模型
|
||||
|
||||
### ASRModel
|
||||
|
||||
```typescript
|
||||
interface ASRModel {
|
||||
id: string; // 模型唯一标识 (8位UUID)
|
||||
user_id: number; // 所属用户ID
|
||||
name: string; // 模型显示名称
|
||||
vendor: string; // 供应商: "OpenAI" | "SiliconFlow" | "Paraformer" | 等
|
||||
language: string; // 识别语言: "zh" | "en" | "Multi-lingual"
|
||||
base_url: string; // API Base URL
|
||||
api_key: string; // API Key
|
||||
model_name?: string; // 模型名称,如 "whisper-1" | "paraformer-v2"
|
||||
hotwords?: string[]; // 热词列表
|
||||
enable_punctuation: boolean; // 是否启用标点
|
||||
enable_normalization: boolean; // 是否启用文本规范化
|
||||
enabled: boolean; // 是否启用
|
||||
created_at: string;
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## API 端点
|
||||
|
||||
### 1. 获取 ASR 模型列表
|
||||
|
||||
```http
|
||||
GET /api/v1/asr
|
||||
```
|
||||
|
||||
**Query Parameters:**
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| language | string | 否 | - | 过滤语言: "zh" \| "en" \| "Multi-lingual" |
|
||||
| enabled | boolean | 否 | - | 过滤启用状态 |
|
||||
| page | int | 否 | 1 | 页码 |
|
||||
| limit | int | 否 | 50 | 每页数量 |
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"total": 3,
|
||||
"page": 1,
|
||||
"limit": 50,
|
||||
"list": [
|
||||
{
|
||||
"id": "abc12345",
|
||||
"user_id": 1,
|
||||
"name": "Whisper 多语种识别",
|
||||
"vendor": "OpenAI",
|
||||
"language": "Multi-lingual",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-***",
|
||||
"model_name": "whisper-1",
|
||||
"enable_punctuation": true,
|
||||
"enable_normalization": true,
|
||||
"enabled": true,
|
||||
"created_at": "2024-01-15T10:30:00Z"
|
||||
},
|
||||
{
|
||||
"id": "def67890",
|
||||
"user_id": 1,
|
||||
"name": "SenseVoice 中文识别",
|
||||
"vendor": "SiliconFlow",
|
||||
"language": "zh",
|
||||
"base_url": "https://api.siliconflow.cn/v1",
|
||||
"api_key": "sf-***",
|
||||
"model_name": "paraformer-v2",
|
||||
"hotwords": ["小助手", "帮我"],
|
||||
"enable_punctuation": true,
|
||||
"enable_normalization": true,
|
||||
"enabled": true,
|
||||
"created_at": "2024-01-15T10:30:00Z"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. 获取单个 ASR 模型详情
|
||||
|
||||
```http
|
||||
GET /api/v1/asr/{id}
|
||||
```
|
||||
|
||||
**Path Parameters:**
|
||||
|
||||
| 参数 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| id | string | 模型ID |
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "abc12345",
|
||||
"user_id": 1,
|
||||
"name": "Whisper 多语种识别",
|
||||
"vendor": "OpenAI",
|
||||
"language": "Multi-lingual",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-***",
|
||||
"model_name": "whisper-1",
|
||||
"hotwords": [],
|
||||
"enable_punctuation": true,
|
||||
"enable_normalization": true,
|
||||
"enabled": true,
|
||||
"created_at": "2024-01-15T10:30:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 3. 创建 ASR 模型
|
||||
|
||||
```http
|
||||
POST /api/v1/asr
|
||||
```
|
||||
|
||||
**Request Body:**
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "SenseVoice 中文识别",
|
||||
"vendor": "SiliconFlow",
|
||||
"language": "zh",
|
||||
"base_url": "https://api.siliconflow.cn/v1",
|
||||
"api_key": "sk-your-api-key",
|
||||
"model_name": "paraformer-v2",
|
||||
"hotwords": ["小助手", "帮我"],
|
||||
"enable_punctuation": true,
|
||||
"enable_normalization": true,
|
||||
"enabled": true
|
||||
}
|
||||
```
|
||||
|
||||
**Fields 说明:**
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
|------|------|------|------|
|
||||
| name | string | 是 | 模型显示名称 |
|
||||
| vendor | string | 是 | 供应商: "OpenAI" / "SiliconFlow" / "Paraformer" |
|
||||
| language | string | 是 | 语言: "zh" / "en" / "Multi-lingual" |
|
||||
| base_url | string | 是 | API Base URL |
|
||||
| api_key | string | 是 | API Key |
|
||||
| model_name | string | 否 | 模型名称 |
|
||||
| hotwords | string[] | 否 | 热词列表,提升识别准确率 |
|
||||
| enable_punctuation | boolean | 否 | 是否输出标点,默认 true |
|
||||
| enable_normalization | boolean | 否 | 是否文本规范化,默认 true |
|
||||
| enabled | boolean | 否 | 是否启用,默认 true |
|
||||
| id | string | 否 | 指定模型ID,默认自动生成 |
|
||||
|
||||
---
|
||||
|
||||
### 4. 更新 ASR 模型
|
||||
|
||||
```http
|
||||
PUT /api/v1/asr/{id}
|
||||
```
|
||||
|
||||
**Request Body:** (部分更新)
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "Whisper-1 优化版",
|
||||
"language": "zh",
|
||||
"enable_punctuation": true,
|
||||
"hotwords": ["新词1", "新词2"]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 5. 删除 ASR 模型
|
||||
|
||||
```http
|
||||
DELETE /api/v1/asr/{id}
|
||||
```
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"message": "Deleted successfully"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 6. 测试 ASR 模型
|
||||
|
||||
```http
|
||||
POST /api/v1/asr/{id}/test
|
||||
```
|
||||
|
||||
**Request Body:**
|
||||
|
||||
```json
|
||||
{
|
||||
"audio_url": "https://example.com/test-audio.wav"
|
||||
}
|
||||
```
|
||||
|
||||
或使用 Base64 编码的音频数据:
|
||||
|
||||
```json
|
||||
{
|
||||
"audio_data": "UklGRi..."
|
||||
}
|
||||
```
|
||||
|
||||
**Response (成功):**
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"transcript": "您好,请问有什么可以帮助您?",
|
||||
"language": "zh",
|
||||
"confidence": 0.95,
|
||||
"latency_ms": 500
|
||||
}
|
||||
```
|
||||
|
||||
**Response (失败):**
|
||||
|
||||
```json
|
||||
{
|
||||
"success": false,
|
||||
"error": "HTTP Error: 401 - Unauthorized"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 7. 转写音频
|
||||
|
||||
```http
|
||||
POST /api/v1/asr/{id}/transcribe
|
||||
```
|
||||
|
||||
**Query Parameters:**
|
||||
|
||||
| 参数 | 类型 | 必填 | 说明 |
|
||||
|------|------|------|------|
|
||||
| audio_url | string | 否* | 音频文件URL |
|
||||
| audio_data | string | 否* | Base64编码的音频数据 |
|
||||
| hotwords | string[] | 否 | 热词列表 |
|
||||
|
||||
*二选一,至少提供一个
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"transcript": "您好,请问有什么可以帮助您?",
|
||||
"language": "zh",
|
||||
"confidence": 0.95
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Schema 定义
|
||||
|
||||
```python
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
class ASRLanguage(str, Enum):
|
||||
ZH = "zh"
|
||||
EN = "en"
|
||||
MULTILINGUAL = "Multi-lingual"
|
||||
|
||||
class ASRModelBase(BaseModel):
|
||||
name: str
|
||||
vendor: str
|
||||
language: str # "zh" | "en" | "Multi-lingual"
|
||||
base_url: str
|
||||
api_key: str
|
||||
model_name: Optional[str] = None
|
||||
hotwords: List[str] = []
|
||||
enable_punctuation: bool = True
|
||||
enable_normalization: bool = True
|
||||
enabled: bool = True
|
||||
|
||||
class ASRModelCreate(ASRModelBase):
|
||||
id: Optional[str] = None
|
||||
|
||||
class ASRModelUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
hotwords: Optional[List[str]] = None
|
||||
enable_punctuation: Optional[bool] = None
|
||||
enable_normalization: Optional[bool] = None
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
class ASRModelOut(ASRModelBase):
|
||||
id: str
|
||||
user_id: int
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class ASRTestRequest(BaseModel):
|
||||
audio_url: Optional[str] = None
|
||||
audio_data: Optional[str] = None # base64 encoded
|
||||
|
||||
class ASRTestResponse(BaseModel):
|
||||
success: bool
|
||||
transcript: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
confidence: Optional[float] = None
|
||||
latency_ms: Optional[int] = None
|
||||
error: Optional[str] = None
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 供应商配置示例
|
||||
|
||||
### OpenAI Whisper
|
||||
|
||||
```json
|
||||
{
|
||||
"vendor": "OpenAI",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-xxx",
|
||||
"model_name": "whisper-1",
|
||||
"language": "Multi-lingual",
|
||||
"enable_punctuation": true,
|
||||
"enable_normalization": true
|
||||
}
|
||||
```
|
||||
|
||||
### SiliconFlow Paraformer
|
||||
|
||||
```json
|
||||
{
|
||||
"vendor": "SiliconFlow",
|
||||
"base_url": "https://api.siliconflow.cn/v1",
|
||||
"api_key": "sf-xxx",
|
||||
"model_name": "paraformer-v2",
|
||||
"language": "zh",
|
||||
"hotwords": ["产品名称", "公司名"],
|
||||
"enable_punctuation": true,
|
||||
"enable_normalization": true
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 单元测试
|
||||
|
||||
项目包含完整的单元测试,位于 `api/tests/test_asr.py`。
|
||||
|
||||
### 测试用例概览
|
||||
|
||||
| 测试方法 | 说明 |
|
||||
|----------|------|
|
||||
| test_get_asr_models_empty | 空数据库获取测试 |
|
||||
| test_create_asr_model | 创建模型测试 |
|
||||
| test_create_asr_model_minimal | 最小数据创建测试 |
|
||||
| test_get_asr_model_by_id | 获取单个模型测试 |
|
||||
| test_get_asr_model_not_found | 获取不存在模型测试 |
|
||||
| test_update_asr_model | 更新模型测试 |
|
||||
| test_delete_asr_model | 删除模型测试 |
|
||||
| test_list_asr_models_with_pagination | 分页测试 |
|
||||
| test_filter_asr_models_by_language | 按语言过滤测试 |
|
||||
| test_filter_asr_models_by_enabled | 按启用状态过滤测试 |
|
||||
| test_create_asr_model_with_hotwords | 热词配置测试 |
|
||||
| test_test_asr_model_siliconflow | SiliconFlow 供应商测试 |
|
||||
| test_test_asr_model_openai | OpenAI 供应商测试 |
|
||||
| test_different_asr_languages | 多语言测试 |
|
||||
| test_different_asr_vendors | 多供应商测试 |
|
||||
|
||||
### 运行测试
|
||||
|
||||
```bash
|
||||
# 运行 ASR 相关测试
|
||||
pytest api/tests/test_asr.py -v
|
||||
|
||||
# 运行所有测试
|
||||
pytest api/tests/ -v
|
||||
```
|
||||
@@ -7,9 +7,9 @@
|
||||
| 模块 | 文件 | 说明 |
|
||||
|------|------|------|
|
||||
| 小助手 | [assistant.md](./assistant.md) | AI 助手管理 |
|
||||
| 模型接入 | [model-access.md](./model-access.md) | LLM/ASR/TTS 模型配置 |
|
||||
| 语音识别 | [speech-recognition.md](./speech-recognition.md) | ASR 模型配置 |
|
||||
| 声音资源 | [voice-resources.md](./voice-resources.md) | TTS 声音库管理 |
|
||||
| LLM 模型 | [llm.md](./llm.md) | LLM 模型配置与管理 |
|
||||
| ASR 模型 | [asr.md](./asr.md) | 语音识别模型配置 |
|
||||
| 工具与测试 | [tools.md](./tools.md) | 工具列表与自动测试 |
|
||||
| 历史记录 | [history-records.md](./history-records.md) | 通话记录和转写 |
|
||||
|
||||
---
|
||||
|
||||
401
api/docs/llm.md
Normal file
401
api/docs/llm.md
Normal file
@@ -0,0 +1,401 @@
|
||||
# LLM 模型 (LLM Model) API
|
||||
|
||||
LLM 模型 API 用于管理大语言模型的配置和调用。
|
||||
|
||||
## 基础信息
|
||||
|
||||
| 项目 | 值 |
|
||||
|------|-----|
|
||||
| Base URL | `/api/v1/llm` |
|
||||
| 认证方式 | Bearer Token (预留) |
|
||||
|
||||
---
|
||||
|
||||
## 数据模型
|
||||
|
||||
### LLMModel
|
||||
|
||||
```typescript
|
||||
interface LLMModel {
|
||||
id: string; // 模型唯一标识 (8位UUID)
|
||||
user_id: number; // 所属用户ID
|
||||
name: string; // 模型显示名称
|
||||
vendor: string; // 供应商: "OpenAI" | "SiliconFlow" | "Dify" | "FastGPT" | 等
|
||||
type: string; // 类型: "text" | "embedding" | "rerank"
|
||||
base_url: string; // API Base URL
|
||||
api_key: string; // API Key
|
||||
model_name?: string; // 实际模型名称,如 "gpt-4o"
|
||||
temperature?: number; // 温度参数 (0-2)
|
||||
  context_length?: number;     // 上下文长度
|
||||
enabled: boolean; // 是否启用
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## API 端点
|
||||
|
||||
### 1. 获取 LLM 模型列表
|
||||
|
||||
```http
|
||||
GET /api/v1/llm
|
||||
```
|
||||
|
||||
**Query Parameters:**
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| model_type | string | 否 | - | 过滤类型: "text" \| "embedding" \| "rerank" |
|
||||
| enabled | boolean | 否 | - | 过滤启用状态 |
|
||||
| page | int | 否 | 1 | 页码 |
|
||||
| limit | int | 否 | 50 | 每页数量 |
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"total": 5,
|
||||
"page": 1,
|
||||
"limit": 50,
|
||||
"list": [
|
||||
{
|
||||
"id": "abc12345",
|
||||
"user_id": 1,
|
||||
"name": "GPT-4o",
|
||||
"vendor": "OpenAI",
|
||||
"type": "text",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-***",
|
||||
"model_name": "gpt-4o",
|
||||
"temperature": 0.7,
|
||||
"context_length": 128000,
|
||||
"enabled": true,
|
||||
"created_at": "2024-01-15T10:30:00Z",
|
||||
"updated_at": "2024-01-15T10:30:00Z"
|
||||
},
|
||||
{
|
||||
"id": "def67890",
|
||||
"user_id": 1,
|
||||
"name": "Embedding-3-Small",
|
||||
"vendor": "OpenAI",
|
||||
"type": "embedding",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-***",
|
||||
"model_name": "text-embedding-3-small",
|
||||
"enabled": true
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. 获取单个 LLM 模型详情
|
||||
|
||||
```http
|
||||
GET /api/v1/llm/{id}
|
||||
```
|
||||
|
||||
**Path Parameters:**
|
||||
|
||||
| 参数 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| id | string | 模型ID |
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "abc12345",
|
||||
"user_id": 1,
|
||||
"name": "GPT-4o",
|
||||
"vendor": "OpenAI",
|
||||
"type": "text",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-***",
|
||||
"model_name": "gpt-4o",
|
||||
"temperature": 0.7,
|
||||
"context_length": 128000,
|
||||
"enabled": true,
|
||||
"created_at": "2024-01-15T10:30:00Z",
|
||||
"updated_at": "2024-01-15T10:30:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 3. 创建 LLM 模型
|
||||
|
||||
```http
|
||||
POST /api/v1/llm
|
||||
```
|
||||
|
||||
**Request Body:**
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "GPT-4o",
|
||||
"vendor": "OpenAI",
|
||||
"type": "text",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-your-api-key",
|
||||
"model_name": "gpt-4o",
|
||||
"temperature": 0.7,
|
||||
"context_length": 128000,
|
||||
"enabled": true
|
||||
}
|
||||
```
|
||||
|
||||
**Fields 说明:**
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
|------|------|------|------|
|
||||
| name | string | 是 | 模型显示名称 |
|
||||
| vendor | string | 是 | 供应商名称 |
|
||||
| type | string | 是 | 模型类型: "text" / "embedding" / "rerank" |
|
||||
| base_url | string | 是 | API Base URL |
|
||||
| api_key | string | 是 | API Key |
|
||||
| model_name | string | 否 | 实际模型名称 |
|
||||
| temperature | number | 否 | 温度参数,默认 0.7 |
|
||||
| context_length | int | 否 | 上下文长度 |
|
||||
| enabled | boolean | 否 | 是否启用,默认 true |
|
||||
| id | string | 否 | 指定模型ID,默认自动生成 |
|
||||
|
||||
---
|
||||
|
||||
### 4. 更新 LLM 模型
|
||||
|
||||
```http
|
||||
PUT /api/v1/llm/{id}
|
||||
```
|
||||
|
||||
**Request Body:** (部分更新)
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "GPT-4o-Updated",
|
||||
"temperature": 0.8,
|
||||
"enabled": false
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 5. 删除 LLM 模型
|
||||
|
||||
```http
|
||||
DELETE /api/v1/llm/{id}
|
||||
```
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"message": "Deleted successfully"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 6. 测试 LLM 模型连接
|
||||
|
||||
```http
|
||||
POST /api/v1/llm/{id}/test
|
||||
```
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"latency_ms": 150,
|
||||
"message": "Connection successful"
|
||||
}
|
||||
```
|
||||
|
||||
**错误响应:**
|
||||
|
||||
```json
|
||||
{
|
||||
"success": false,
|
||||
"latency_ms": 200,
|
||||
"message": "HTTP Error: 401 - Unauthorized"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 7. 与 LLM 模型对话
|
||||
|
||||
```http
|
||||
POST /api/v1/llm/{id}/chat
|
||||
```
|
||||
|
||||
**Query Parameters:**
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| message | string | 是 | - | 用户消息 |
|
||||
| system_prompt | string | 否 | - | 系统提示词 |
|
||||
| max_tokens | int | 否 | 1000 | 最大生成token数 |
|
||||
| temperature | number | 否 | 模型配置 | 温度参数 |
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"reply": "您好!有什么可以帮助您的?",
|
||||
"usage": {
|
||||
"prompt_tokens": 20,
|
||||
"completion_tokens": 15,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Schema 定义
|
||||
|
||||
```python
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
class LLMModelType(str, Enum):
|
||||
TEXT = "text"
|
||||
EMBEDDING = "embedding"
|
||||
RERANK = "rerank"
|
||||
|
||||
class LLMModelBase(BaseModel):
|
||||
name: str
|
||||
vendor: str
|
||||
type: LLMModelType
|
||||
base_url: str
|
||||
api_key: str
|
||||
model_name: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
context_length: Optional[int] = None
|
||||
enabled: bool = True
|
||||
|
||||
class LLMModelCreate(LLMModelBase):
|
||||
id: Optional[str] = None
|
||||
|
||||
class LLMModelUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
vendor: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
context_length: Optional[int] = None
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
class LLMModelOut(LLMModelBase):
|
||||
id: str
|
||||
user_id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class LLMModelTestResponse(BaseModel):
|
||||
success: bool
|
||||
latency_ms: int
|
||||
message: str
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 供应商配置示例
|
||||
|
||||
### OpenAI
|
||||
|
||||
```json
|
||||
{
|
||||
"vendor": "OpenAI",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-xxx",
|
||||
"model_name": "gpt-4o",
|
||||
"type": "text",
|
||||
"temperature": 0.7
|
||||
}
|
||||
```
|
||||
|
||||
### SiliconFlow
|
||||
|
||||
```json
|
||||
{
|
||||
"vendor": "SiliconFlow",
|
||||
  "base_url": "https://api.siliconflow.cn/v1",
|
||||
"api_key": "sf-xxx",
|
||||
"model_name": "deepseek-v3",
|
||||
"type": "text",
|
||||
"temperature": 0.7
|
||||
}
|
||||
```
|
||||
|
||||
### Dify
|
||||
|
||||
```json
|
||||
{
|
||||
"vendor": "Dify",
|
||||
"base_url": "https://your-dify.domain.com/v1",
|
||||
"api_key": "app-xxx",
|
||||
"model_name": "gpt-4",
|
||||
"type": "text"
|
||||
}
|
||||
```
|
||||
|
||||
### Embedding 模型
|
||||
|
||||
```json
|
||||
{
|
||||
"vendor": "OpenAI",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-xxx",
|
||||
"model_name": "text-embedding-3-small",
|
||||
"type": "embedding"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 单元测试
|
||||
|
||||
项目包含完整的单元测试,位于 `api/tests/test_llm.py`。
|
||||
|
||||
### 测试用例概览
|
||||
|
||||
| 测试方法 | 说明 |
|
||||
|----------|------|
|
||||
| test_get_llm_models_empty | 空数据库获取测试 |
|
||||
| test_create_llm_model | 创建模型测试 |
|
||||
| test_create_llm_model_minimal | 最小数据创建测试 |
|
||||
| test_get_llm_model_by_id | 获取单个模型测试 |
|
||||
| test_get_llm_model_not_found | 获取不存在模型测试 |
|
||||
| test_update_llm_model | 更新模型测试 |
|
||||
| test_delete_llm_model | 删除模型测试 |
|
||||
| test_list_llm_models_with_pagination | 分页测试 |
|
||||
| test_filter_llm_models_by_type | 按类型过滤测试 |
|
||||
| test_filter_llm_models_by_enabled | 按启用状态过滤测试 |
|
||||
| test_create_llm_model_with_all_fields | 全字段创建测试 |
|
||||
| test_test_llm_model_success | 测试连接成功测试 |
|
||||
| test_test_llm_model_failure | 测试连接失败测试 |
|
||||
| test_different_llm_vendors | 多供应商测试 |
|
||||
| test_embedding_llm_model | Embedding 模型测试 |
|
||||
|
||||
### 运行测试
|
||||
|
||||
```bash
|
||||
# 运行 LLM 相关测试
|
||||
pytest api/tests/test_llm.py -v
|
||||
|
||||
# 运行所有测试
|
||||
pytest api/tests/ -v
|
||||
```
|
||||
445
api/docs/tools.md
Normal file
445
api/docs/tools.md
Normal file
@@ -0,0 +1,445 @@
|
||||
# 工具与自动测试 (Tools & Autotest) API
|
||||
|
||||
工具与自动测试 API 用于管理可用工具列表和自动测试功能。
|
||||
|
||||
## 基础信息
|
||||
|
||||
| 项目 | 值 |
|
||||
|------|-----|
|
||||
| Base URL | `/api/v1/tools` |
|
||||
| 认证方式 | Bearer Token (预留) |
|
||||
|
||||
---
|
||||
|
||||
## 可用工具 (Tool Registry)
|
||||
|
||||
系统内置以下工具:
|
||||
|
||||
| 工具ID | 名称 | 说明 |
|
||||
|--------|------|------|
|
||||
| search | 网络搜索 | 搜索互联网获取最新信息 |
|
||||
| calculator | 计算器 | 执行数学计算 |
|
||||
| weather | 天气查询 | 查询指定城市的天气 |
|
||||
| translate | 翻译 | 翻译文本到指定语言 |
|
||||
| knowledge | 知识库查询 | 从知识库中检索相关信息 |
|
||||
| code_interpreter | 代码执行 | 安全地执行Python代码 |
|
||||
|
||||
---
|
||||
|
||||
## API 端点
|
||||
|
||||
### 1. 获取可用工具列表
|
||||
|
||||
```http
|
||||
GET /api/v1/tools/list
|
||||
```
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"search": {
|
||||
"name": "网络搜索",
|
||||
"description": "搜索互联网获取最新信息",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索关键词"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
"calculator": {
|
||||
"name": "计算器",
|
||||
"description": "执行数学计算",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
},
|
||||
"weather": {
|
||||
"name": "天气查询",
|
||||
"description": "查询指定城市的天气",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "城市名称"}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
},
|
||||
"translate": {
|
||||
"name": "翻译",
|
||||
"description": "翻译文本到指定语言",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "要翻译的文本"},
|
||||
"target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"}
|
||||
},
|
||||
"required": ["text", "target_lang"]
|
||||
}
|
||||
},
|
||||
"knowledge": {
|
||||
"name": "知识库查询",
|
||||
"description": "从知识库中检索相关信息",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "查询内容"},
|
||||
"kb_id": {"type": "string", "description": "知识库ID"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
"code_interpreter": {
|
||||
"name": "代码执行",
|
||||
"description": "安全地执行Python代码",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {"type": "string", "description": "要执行的Python代码"}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. 获取工具详情
|
||||
|
||||
```http
|
||||
GET /api/v1/tools/list/{tool_id}
|
||||
```
|
||||
|
||||
**Path Parameters:**
|
||||
|
||||
| 参数 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| tool_id | string | 工具ID |
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "计算器",
|
||||
"description": "执行数学计算",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**错误响应 (工具不存在):**
|
||||
|
||||
```json
|
||||
{
|
||||
"detail": "Tool not found"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 3. 健康检查
|
||||
|
||||
```http
|
||||
GET /api/v1/tools/health
|
||||
```
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "healthy",
|
||||
"timestamp": 1705315200.123,
|
||||
"tools": ["search", "calculator", "weather", "translate", "knowledge", "code_interpreter"]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 自动测试 (Autotest)
|
||||
|
||||
### 4. 运行完整自动测试
|
||||
|
||||
```http
|
||||
POST /api/v1/tools/autotest
|
||||
```
|
||||
|
||||
**Query Parameters:**
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| llm_model_id | string | 否 | - | LLM 模型ID |
|
||||
| asr_model_id | string | 否 | - | ASR 模型ID |
|
||||
| test_llm | boolean | 否 | true | 是否测试LLM |
|
||||
| test_asr | boolean | 否 | true | 是否测试ASR |
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "abc12345",
|
||||
"started_at": 1705315200.0,
|
||||
"duration_ms": 2500,
|
||||
"tests": [
|
||||
{
|
||||
"name": "Model Existence",
|
||||
"passed": true,
|
||||
"message": "Found model: GPT-4o",
|
||||
"duration_ms": 15
|
||||
},
|
||||
{
|
||||
"name": "API Connection",
|
||||
"passed": true,
|
||||
"message": "Latency: 150ms",
|
||||
"duration_ms": 150
|
||||
},
|
||||
{
|
||||
"name": "Temperature Setting",
|
||||
"passed": true,
|
||||
"message": "temperature=0.7"
|
||||
},
|
||||
{
|
||||
"name": "Streaming Support",
|
||||
"passed": true,
|
||||
"message": "Received 15 chunks",
|
||||
"duration_ms": 800
|
||||
}
|
||||
],
|
||||
"summary": {
|
||||
"passed": 4,
|
||||
"failed": 0,
|
||||
"total": 4
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 5. 测试单个 LLM 模型
|
||||
|
||||
```http
|
||||
POST /api/v1/tools/autotest/llm/{model_id}
|
||||
```
|
||||
|
||||
**Path Parameters:**
|
||||
|
||||
| 参数 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| model_id | string | LLM 模型ID |
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "llm_test_001",
|
||||
"started_at": 1705315200.0,
|
||||
"duration_ms": 1200,
|
||||
"tests": [
|
||||
{
|
||||
"name": "Model Existence",
|
||||
"passed": true,
|
||||
"message": "Found model: GPT-4o",
|
||||
"duration_ms": 10
|
||||
},
|
||||
{
|
||||
"name": "API Connection",
|
||||
"passed": true,
|
||||
"message": "Latency: 180ms",
|
||||
"duration_ms": 180
|
||||
},
|
||||
{
|
||||
"name": "Temperature Setting",
|
||||
"passed": true,
|
||||
"message": "temperature=0.7"
|
||||
},
|
||||
{
|
||||
"name": "Streaming Support",
|
||||
"passed": true,
|
||||
"message": "Received 12 chunks",
|
||||
"duration_ms": 650
|
||||
}
|
||||
],
|
||||
"summary": {
|
||||
"passed": 4,
|
||||
"failed": 0,
|
||||
"total": 4
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 6. 测试单个 ASR 模型
|
||||
|
||||
```http
|
||||
POST /api/v1/tools/autotest/asr/{model_id}
|
||||
```
|
||||
|
||||
**Path Parameters:**
|
||||
|
||||
| 参数 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| model_id | string | ASR 模型ID |
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "asr_test_001",
|
||||
"started_at": 1705315200.0,
|
||||
"duration_ms": 800,
|
||||
"tests": [
|
||||
{
|
||||
"name": "Model Existence",
|
||||
"passed": true,
|
||||
"message": "Found model: Whisper-1",
|
||||
"duration_ms": 8
|
||||
},
|
||||
{
|
||||
"name": "Hotwords Config",
|
||||
"passed": true,
|
||||
"message": "Hotwords: 3 words"
|
||||
},
|
||||
{
|
||||
"name": "API Availability",
|
||||
"passed": true,
|
||||
"message": "Status: 200",
|
||||
"duration_ms": 250
|
||||
},
|
||||
{
|
||||
"name": "Language Config",
|
||||
"passed": true,
|
||||
"message": "Language: zh"
|
||||
}
|
||||
],
|
||||
"summary": {
|
||||
"passed": 4,
|
||||
"failed": 0,
|
||||
"total": 4
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 7. 发送测试消息
|
||||
|
||||
```http
|
||||
POST /api/v1/tools/test-message
|
||||
```
|
||||
|
||||
**Query Parameters:**
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| llm_model_id | string | 是 | - | LLM 模型ID |
|
||||
| message | string | 否 | "Hello, this is a test message." | 测试消息 |
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"reply": "Hello! This is a test reply from GPT-4o.",
|
||||
"usage": {
|
||||
"prompt_tokens": 15,
|
||||
"completion_tokens": 12,
|
||||
"total_tokens": 27
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**错误响应 (模型不存在):**
|
||||
|
||||
```json
|
||||
{
|
||||
"detail": "LLM Model not found"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 测试结果结构
|
||||
|
||||
### AutotestResult
|
||||
|
||||
```typescript
|
||||
interface AutotestResult {
|
||||
id: string; // 测试ID
|
||||
started_at: number; // 开始时间戳
|
||||
duration_ms: number; // 总耗时(毫秒)
|
||||
tests: TestCase[]; // 测试用例列表
|
||||
summary: TestSummary; // 测试摘要
|
||||
}
|
||||
|
||||
interface TestCase {
|
||||
name: string; // 测试名称
|
||||
passed: boolean; // 是否通过
|
||||
message: string; // 测试消息
|
||||
duration_ms: number; // 耗时(毫秒)
|
||||
}
|
||||
|
||||
interface TestSummary {
|
||||
passed: number; // 通过数量
|
||||
failed: number; // 失败数量
|
||||
total: number; // 总数量
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 测试项目说明
|
||||
|
||||
### LLM 模型测试项目
|
||||
|
||||
| 测试名称 | 说明 |
|
||||
|----------|------|
|
||||
| Model Existence | 检查模型是否存在于数据库 |
|
||||
| API Connection | 测试 API 连接并测量延迟 |
|
||||
| Temperature Setting | 检查温度配置 |
|
||||
| Streaming Support | 测试流式响应支持 |
|
||||
|
||||
### ASR 模型测试项目
|
||||
|
||||
| 测试名称 | 说明 |
|
||||
|----------|------|
|
||||
| Model Existence | 检查模型是否存在于数据库 |
|
||||
| Hotwords Config | 检查热词配置 |
|
||||
| API Availability | 测试 API 可用性 |
|
||||
| Language Config | 检查语言配置 |
|
||||
|
||||
---
|
||||
|
||||
## 单元测试
|
||||
|
||||
项目包含完整的单元测试,位于 `api/tests/test_tools.py`。
|
||||
|
||||
### 测试用例概览
|
||||
|
||||
| 测试类 | 说明 |
|
||||
|--------|------|
|
||||
| TestToolsAPI | 工具列表、健康检查等基础功能测试 |
|
||||
| TestAutotestAPI | 自动测试功能完整测试 |
|
||||
|
||||
### 运行测试
|
||||
|
||||
```bash
|
||||
# 运行工具相关测试
|
||||
pytest api/tests/test_tools.py -v
|
||||
|
||||
# 运行所有测试
|
||||
pytest api/tests/ -v
|
||||
```
|
||||
371
api/init_db.py
371
api/init_db.py
@@ -7,7 +7,7 @@ import sys
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.db import Base, engine, DATABASE_URL
|
||||
from app.models import Voice
|
||||
from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel
|
||||
|
||||
|
||||
def init_db():
|
||||
@@ -22,10 +22,379 @@ def init_db():
|
||||
print("✅ 数据库表创建完成")
|
||||
|
||||
|
||||
def init_default_data():
    """Seed the Voice table with SiliconFlow CosyVoice 2.0 preset voices.

    Idempotent: rows are inserted only when the table is empty, so the
    script can be re-run without creating duplicates.
    """
    from app.db import SessionLocal

    db = SessionLocal()
    try:
        # Skip seeding if any voice already exists.
        if db.query(Voice).count() == 0:
            # SiliconFlow CosyVoice 2.0 preset voices (8 system presets)
            # Ref: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
            voices = [
                # Male voices (system presets)
                Voice(id="alex", name="Alex", vendor="SiliconFlow", gender="Male", language="en",
                      description="Steady male voice.", is_system=True),
                Voice(id="benjamin", name="Benjamin", vendor="SiliconFlow", gender="Male", language="en",
                      description="Deep male voice.", is_system=True),
                Voice(id="charles", name="Charles", vendor="SiliconFlow", gender="Male", language="en",
                      description="Magnetic male voice.", is_system=True),
                Voice(id="david", name="David", vendor="SiliconFlow", gender="Male", language="en",
                      description="Cheerful male voice.", is_system=True),
                # Female voices (system presets)
                Voice(id="anna", name="Anna", vendor="SiliconFlow", gender="Female", language="en",
                      description="Steady female voice.", is_system=True),
                Voice(id="bella", name="Bella", vendor="SiliconFlow", gender="Female", language="en",
                      description="Passionate female voice.", is_system=True),
                Voice(id="claire", name="Claire", vendor="SiliconFlow", gender="Female", language="en",
                      description="Gentle female voice.", is_system=True),
                Voice(id="diana", name="Diana", vendor="SiliconFlow", gender="Female", language="en",
                      description="Cheerful female voice.", is_system=True),
                # Optional extra voices (not flagged is_system)
                Voice(id="amador", name="Amador", vendor="SiliconFlow", gender="Male", language="zh",
                      description="Male voice with Spanish accent."),
                Voice(id="aelora", name="Aelora", vendor="SiliconFlow", gender="Female", language="en",
                      description="Elegant female voice."),
                Voice(id="aelwin", name="Aelwin", vendor="SiliconFlow", gender="Male", language="en",
                      description="Deep male voice."),
                Voice(id="blooming", name="Blooming", vendor="SiliconFlow", gender="Female", language="en",
                      description="Fresh and clear female voice."),
                Voice(id="elysia", name="Elysia", vendor="SiliconFlow", gender="Female", language="en",
                      description="Smooth and silky female voice."),
                Voice(id="leo", name="Leo", vendor="SiliconFlow", gender="Male", language="en",
                      description="Young male voice."),
                Voice(id="lin", name="Lin", vendor="SiliconFlow", gender="Female", language="zh",
                      description="Standard Chinese female voice."),
                Voice(id="rose", name="Rose", vendor="SiliconFlow", gender="Female", language="en",
                      description="Soft and gentle female voice."),
                Voice(id="shao", name="Shao", vendor="SiliconFlow", gender="Male", language="zh",
                      description="Deep Chinese male voice."),
                Voice(id="sky", name="Sky", vendor="SiliconFlow", gender="Male", language="en",
                      description="Clear and bright male voice."),
                Voice(id="ael西山", name="Ael西山", vendor="SiliconFlow", gender="Female", language="zh",
                      description="Female voice with Chinese dialect."),
            ]
            # Bulk add instead of a per-row loop (same transaction semantics).
            db.add_all(voices)
            db.commit()
            print("✅ 默认声音数据已初始化 (SiliconFlow CosyVoice 2.0)")
    finally:
        db.close()
|
||||
|
||||
|
||||
def init_default_assistants():
    """Seed three demo assistants (general, customer-service, English tutor).

    Idempotent: inserts only when the Assistant table is empty.
    """
    from app.db import SessionLocal

    db = SessionLocal()
    try:
        if db.query(Assistant).count() == 0:
            assistants = [
                # General-purpose Chinese assistant, wired to platform-managed models.
                Assistant(
                    id="default",
                    user_id=1,
                    name="AI 助手",
                    call_count=0,
                    opener="你好!我是AI助手,有什么可以帮你的吗?",
                    prompt="你是一个友好的AI助手,请用简洁清晰的语言回答用户的问题。",
                    language="zh",
                    voice="anna",
                    speed=1.0,
                    hotwords=[],
                    tools=["search", "calculator"],
                    interruption_sensitivity=500,
                    config_mode="platform",
                    llm_model_id="deepseek-chat",
                    asr_model_id="paraformer-v2",
                ),
                # Customer-service persona with domain hotwords.
                Assistant(
                    id="customer_service",
                    user_id=1,
                    name="客服助手",
                    call_count=0,
                    opener="您好,欢迎致电客服中心,请问有什么可以帮您?",
                    prompt="你是一个专业的客服人员,耐心解答客户问题,提供优质的服务体验。",
                    language="zh",
                    voice="bella",
                    speed=1.0,
                    hotwords=["客服", "投诉", "咨询"],
                    tools=["search"],
                    interruption_sensitivity=600,
                    config_mode="platform",
                ),
                # English-learning tutor persona.
                Assistant(
                    id="english_tutor",
                    user_id=1,
                    name="英语导师",
                    call_count=0,
                    opener="Hello! I'm your English learning companion. How can I help you today?",
                    prompt="You are a friendly English tutor. Help users practice English conversation and explain grammar points clearly.",
                    language="en",
                    voice="alex",
                    speed=1.0,
                    hotwords=["grammar", "vocabulary", "practice"],
                    tools=[],
                    interruption_sensitivity=400,
                    config_mode="platform",
                ),
            ]
            # Bulk add instead of a per-row loop.
            db.add_all(assistants)
            db.commit()
            print("✅ 默认助手数据已初始化")
    finally:
        db.close()
|
||||
|
||||
|
||||
def init_default_workflows():
    """Seed two demo workflows (simple chat, full voice pipeline).

    Idempotent: inserts only when the Workflow table is empty.
    """
    from app.db import SessionLocal
    from datetime import datetime

    db = SessionLocal()
    try:
        if db.query(Workflow).count() == 0:
            # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
            # consider datetime.now(timezone.utc) — kept as-is to avoid a
            # behavior change in stored timestamps.
            now = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
            workflows = [
                # Minimal two-node flow: start -> AI reply.
                Workflow(
                    id="simple_conversation",
                    user_id=1,
                    name="简单对话",
                    node_count=2,
                    created_at=now,
                    updated_at=now,
                    global_prompt="处理简单的对话流程,用户问什么答什么。",
                    nodes=[
                        {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
                        {"id": "2", "type": "ai_reply", "position": {"x": 300, "y": 100}, "data": {"label": "AI回复"}},
                    ],
                    edges=[{"source": "1", "target": "2", "id": "e1-2"}],
                ),
                # Full voice pipeline: start -> ASR -> LLM -> TTS.
                Workflow(
                    id="voice_input_flow",
                    user_id=1,
                    name="语音输入流程",
                    node_count=4,
                    created_at=now,
                    updated_at=now,
                    global_prompt="处理语音输入的完整流程。",
                    nodes=[
                        {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
                        {"id": "2", "type": "asr", "position": {"x": 250, "y": 100}, "data": {"label": "语音识别"}},
                        {"id": "3", "type": "llm", "position": {"x": 400, "y": 100}, "data": {"label": "LLM处理"}},
                        {"id": "4", "type": "tts", "position": {"x": 550, "y": 100}, "data": {"label": "语音合成"}},
                    ],
                    edges=[
                        {"source": "1", "target": "2", "id": "e1-2"},
                        {"source": "2", "target": "3", "id": "e2-3"},
                        {"source": "3", "target": "4", "id": "e3-4"},
                    ],
                ),
            ]
            # Bulk add instead of a per-row loop.
            db.add_all(workflows)
            db.commit()
            print("✅ 默认工作流数据已初始化")
    finally:
        db.close()
|
||||
|
||||
|
||||
def init_default_knowledge_bases():
    """Seed a single default knowledge base.

    Idempotent: inserts only when the KnowledgeBase table is empty.
    """
    from app.db import SessionLocal

    db = SessionLocal()
    try:
        if db.query(KnowledgeBase).count() == 0:
            # Empty KB shell; documents/chunks are added through the API later.
            kb = KnowledgeBase(
                id="default_kb",
                user_id=1,
                name="默认知识库",
                description="系统默认知识库,用于存储常见问题解答。",
                embedding_model="text-embedding-3-small",
                chunk_size=500,
                chunk_overlap=50,
                doc_count=0,
                chunk_count=0,
                status="active",
            )
            db.add(kb)
            db.commit()
            print("✅ 默认知识库已初始化")
    finally:
        db.close()
|
||||
|
||||
|
||||
def init_default_llm_models():
    """Seed the default LLM model catalog (chat + embedding entries).

    API keys are placeholders ("YOUR_API_KEY") that users must replace.
    Idempotent: inserts only when the LLMModel table is empty.
    """
    from app.db import SessionLocal

    db = SessionLocal()
    try:
        if db.query(LLMModel).count() == 0:
            llm_models = [
                LLMModel(
                    id="deepseek-chat",
                    user_id=1,
                    name="DeepSeek Chat",
                    vendor="SiliconFlow",
                    type="text",
                    base_url="https://api.deepseek.com",
                    api_key="YOUR_API_KEY",  # placeholder — must be replaced by the user
                    model_name="deepseek-chat",
                    temperature=0.7,
                    context_length=4096,
                    enabled=True,
                ),
                LLMModel(
                    id="deepseek-reasoner",
                    user_id=1,
                    name="DeepSeek Reasoner",
                    vendor="SiliconFlow",
                    type="text",
                    base_url="https://api.deepseek.com",
                    api_key="YOUR_API_KEY",
                    model_name="deepseek-reasoner",
                    temperature=0.7,
                    context_length=4096,
                    enabled=True,
                ),
                LLMModel(
                    id="gpt-4o",
                    user_id=1,
                    name="GPT-4o",
                    vendor="OpenAI",
                    type="text",
                    base_url="https://api.openai.com/v1",
                    api_key="YOUR_API_KEY",
                    model_name="gpt-4o",
                    temperature=0.7,
                    context_length=16384,
                    enabled=True,
                ),
                LLMModel(
                    id="glm-4",
                    user_id=1,
                    name="GLM-4",
                    vendor="ZhipuAI",
                    type="text",
                    base_url="https://open.bigmodel.cn/api/paas/v4",
                    api_key="YOUR_API_KEY",
                    model_name="glm-4",
                    temperature=0.7,
                    context_length=8192,
                    enabled=True,
                ),
                # Embedding model used by the default knowledge base.
                LLMModel(
                    id="text-embedding-3-small",
                    user_id=1,
                    name="Embedding 3 Small",
                    vendor="OpenAI",
                    type="embedding",
                    base_url="https://api.openai.com/v1",
                    api_key="YOUR_API_KEY",
                    model_name="text-embedding-3-small",
                    enabled=True,
                ),
            ]
            # Bulk add instead of a per-row loop.
            db.add_all(llm_models)
            db.commit()
            print("✅ 默认LLM模型已初始化")
    finally:
        db.close()
|
||||
|
||||
|
||||
def init_default_asr_models():
    """Seed the default ASR model catalog.

    API keys are placeholders ("YOUR_API_KEY") that users must replace.
    Idempotent: inserts only when the ASRModel table is empty.
    """
    from app.db import SessionLocal

    db = SessionLocal()
    try:
        if db.query(ASRModel).count() == 0:
            asr_models = [
                ASRModel(
                    id="paraformer-v2",
                    user_id=1,
                    name="Paraformer V2",
                    vendor="SiliconFlow",
                    language="zh",
                    base_url="https://api.siliconflow.cn/v1",
                    api_key="YOUR_API_KEY",  # placeholder — must be replaced by the user
                    model_name="paraformer-v2",
                    hotwords=["人工智能", "机器学习"],
                    enable_punctuation=True,
                    enable_normalization=True,
                    enabled=True,
                ),
                ASRModel(
                    id="paraformer-en",
                    user_id=1,
                    name="Paraformer English",
                    vendor="SiliconFlow",
                    language="en",
                    base_url="https://api.siliconflow.cn/v1",
                    api_key="YOUR_API_KEY",
                    model_name="paraformer-en",
                    hotwords=[],
                    enable_punctuation=True,
                    enable_normalization=True,
                    enabled=True,
                ),
                ASRModel(
                    id="whisper-1",
                    user_id=1,
                    name="Whisper",
                    vendor="OpenAI",
                    language="Multi-lingual",
                    base_url="https://api.openai.com/v1",
                    api_key="YOUR_API_KEY",
                    model_name="whisper-1",
                    hotwords=[],
                    enable_punctuation=True,
                    enable_normalization=True,
                    enabled=True,
                ),
                ASRModel(
                    id="sensevoice",
                    user_id=1,
                    name="SenseVoice",
                    vendor="SiliconFlow",
                    language="Multi-lingual",
                    base_url="https://api.siliconflow.cn/v1",
                    api_key="YOUR_API_KEY",
                    model_name="sensevoice",
                    hotwords=[],
                    enable_punctuation=True,
                    enable_normalization=True,
                    enabled=True,
                ),
            ]
            # Bulk add instead of a per-row loop.
            db.add_all(asr_models)
            db.commit()
            print("✅ 默认ASR模型已初始化")
    finally:
        db.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ensure the sibling ../data directory exists before creating tables
    # (presumably where the SQLite file lives — confirm against app.db).
    data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
    os.makedirs(data_dir, exist_ok=True)

    # Create the schema first, then seed each table in dependency-free order;
    # every seeder is idempotent (no-op when its table already has rows).
    init_db()
    init_default_data()
    init_default_assistants()
    init_default_workflows()
    init_default_knowledge_bases()
    init_default_llm_models()
    init_default_asr_models()
    print("🎉 数据库初始化完成!")
|
||||
|
||||
@@ -100,3 +100,38 @@ def sample_call_record_data():
|
||||
"assistant_id": None,
|
||||
"source": "debug"
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
def sample_llm_model_data():
    """Canonical LLM-model payload consumed by the LLM API tests."""
    return dict(
        id="test-llm-001",
        name="Test LLM Model",
        vendor="TestVendor",
        type="text",
        base_url="https://api.test.com/v1",
        api_key="test-api-key",
        model_name="test-model",
        temperature=0.7,
        context_length=4096,
        enabled=True,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def sample_asr_model_data():
    """Canonical ASR-model payload consumed by the ASR API tests."""
    return dict(
        id="test-asr-001",
        name="Test ASR Model",
        vendor="TestVendor",
        language="zh",
        base_url="https://api.test.com/v1",
        api_key="test-api-key",
        model_name="paraformer-v2",
        hotwords=["测试", "语音"],
        enable_punctuation=True,
        enable_normalization=True,
        enabled=True,
    )
|
||||
|
||||
289
api/tests/test_asr.py
Normal file
289
api/tests/test_asr.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""Tests for ASR Model API endpoints"""
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestASRModelAPI:
    """Test cases for ASR Model endpoints.

    Relies on the project's ``client`` (FastAPI TestClient) and
    ``sample_asr_model_data`` fixtures from conftest.
    """

    def test_get_asr_models_empty(self, client):
        """Test getting ASR models when database is empty"""
        response = client.get("/api/asr")
        assert response.status_code == 200
        data = response.json()
        # List endpoint returns a {"total": int, "list": [...]} envelope.
        assert "total" in data
        assert "list" in data
        assert data["total"] == 0

    def test_create_asr_model(self, client, sample_asr_model_data):
        """Test creating a new ASR model"""
        response = client.post("/api/asr", json=sample_asr_model_data)
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == sample_asr_model_data["name"]
        assert data["vendor"] == sample_asr_model_data["vendor"]
        assert data["language"] == sample_asr_model_data["language"]
        assert "id" in data

    def test_create_asr_model_minimal(self, client):
        """Test creating an ASR model with minimal required data"""
        # Only the required fields; optional ones should take server defaults.
        data = {
            "name": "Minimal ASR",
            "vendor": "Test",
            "language": "zh",
            "base_url": "https://api.test.com",
            "api_key": "test-key"
        }
        response = client.post("/api/asr", json=data)
        assert response.status_code == 200
        assert response.json()["name"] == "Minimal ASR"

    def test_get_asr_model_by_id(self, client, sample_asr_model_data):
        """Test getting a specific ASR model by ID"""
        # Create first
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        # Get by ID
        response = client.get(f"/api/asr/{model_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == model_id
        assert data["name"] == sample_asr_model_data["name"]

    def test_get_asr_model_not_found(self, client):
        """Test getting a non-existent ASR model"""
        response = client.get("/api/asr/non-existent-id")
        assert response.status_code == 404

    def test_update_asr_model(self, client, sample_asr_model_data):
        """Test updating an ASR model"""
        # Create first
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        # Update a subset of fields; untouched fields must be preserved.
        update_data = {
            "name": "Updated ASR Model",
            "language": "en",
            "enable_punctuation": False
        }
        response = client.put(f"/api/asr/{model_id}", json=update_data)
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Updated ASR Model"
        assert data["language"] == "en"
        assert data["enable_punctuation"] == False

    def test_delete_asr_model(self, client, sample_asr_model_data):
        """Test deleting an ASR model"""
        # Create first
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        # Delete
        response = client.delete(f"/api/asr/{model_id}")
        assert response.status_code == 200

        # Verify deleted
        get_response = client.get(f"/api/asr/{model_id}")
        assert get_response.status_code == 404

    def test_list_asr_models_with_pagination(self, client, sample_asr_model_data):
        """Test listing ASR models with pagination"""
        # Create multiple models with distinct ids/names.
        for i in range(3):
            data = sample_asr_model_data.copy()
            data["id"] = f"test-asr-{i}"
            data["name"] = f"ASR Model {i}"
            client.post("/api/asr", json=data)

        # Page 1 of size 2: total counts all rows, list is capped at limit.
        response = client.get("/api/asr?page=1&limit=2")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 3
        assert len(data["list"]) == 2

    def test_filter_asr_models_by_language(self, client, sample_asr_model_data):
        """Test filtering ASR models by language"""
        # Create models with different languages
        for i, lang in enumerate(["zh", "en", "Multi-lingual"]):
            data = sample_asr_model_data.copy()
            data["id"] = f"test-asr-{lang}"
            data["name"] = f"ASR {lang}"
            data["language"] = lang
            client.post("/api/asr", json=data)

        # Filter by language — every returned row must match.
        response = client.get("/api/asr?language=zh")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] >= 1
        for model in data["list"]:
            assert model["language"] == "zh"

    def test_filter_asr_models_by_enabled(self, client, sample_asr_model_data):
        """Test filtering ASR models by enabled status"""
        # Create enabled and disabled models
        data = sample_asr_model_data.copy()
        data["id"] = "test-asr-enabled"
        data["name"] = "Enabled ASR"
        data["enabled"] = True
        client.post("/api/asr", json=data)

        data["id"] = "test-asr-disabled"
        data["name"] = "Disabled ASR"
        data["enabled"] = False
        client.post("/api/asr", json=data)

        # Filter by enabled — every returned row must be enabled.
        response = client.get("/api/asr?enabled=true")
        assert response.status_code == 200
        data = response.json()
        for model in data["list"]:
            assert model["enabled"] == True

    def test_create_asr_model_with_hotwords(self, client):
        """Test creating an ASR model with hotwords"""
        data = {
            "id": "asr-hotwords",
            "name": "ASR with Hotwords",
            "vendor": "SiliconFlow",
            "language": "zh",
            "base_url": "https://api.siliconflow.cn/v1",
            "api_key": "test-key",
            "model_name": "paraformer-v2",
            "hotwords": ["你好", "查询", "帮助"],
            "enable_punctuation": True,
            "enable_normalization": True
        }
        response = client.post("/api/asr", json=data)
        assert response.status_code == 200
        result = response.json()
        # Hotword list must round-trip unchanged (order preserved).
        assert result["hotwords"] == ["你好", "查询", "帮助"]

    def test_create_asr_model_with_all_fields(self, client):
        """Test creating an ASR model with all fields"""
        data = {
            "id": "full-asr",
            "name": "Full ASR Model",
            "vendor": "SiliconFlow",
            "language": "zh",
            "base_url": "https://api.siliconflow.cn/v1",
            "api_key": "sk-test",
            "model_name": "paraformer-v2",
            "hotwords": ["测试"],
            "enable_punctuation": True,
            "enable_normalization": True,
            "enabled": True
        }
        response = client.post("/api/asr", json=data)
        assert response.status_code == 200
        result = response.json()
        assert result["name"] == "Full ASR Model"
        assert result["enable_punctuation"] == True
        assert result["enable_normalization"] == True

    @patch('httpx.Client')
    def test_test_asr_model_siliconflow(self, mock_client_class, client, sample_asr_model_data):
        """Test testing an ASR model with SiliconFlow vendor"""
        # Create model first
        sample_asr_model_data["vendor"] = "SiliconFlow"
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        # Mock the HTTP response so no real network call is made; the mock
        # also acts as its own context manager (__enter__ returns itself).
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "results": [{"transcript": "测试文本", "language": "zh"}]
        }
        mock_response.raise_for_status = MagicMock()
        mock_client.get.return_value = mock_response
        mock_client.__enter__ = MagicMock(return_value=mock_client)
        mock_client.__exit__ = MagicMock(return_value=False)

        # Patch at the router's import site, not the httpx package itself.
        with patch('app.routers.asr.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/asr/{model_id}/test")
            assert response.status_code == 200
            data = response.json()
            assert data["success"] == True

    @patch('httpx.Client')
    def test_test_asr_model_openai(self, mock_client_class, client, sample_asr_model_data):
        """Test testing an ASR model with OpenAI vendor"""
        # Create model with OpenAI vendor
        sample_asr_model_data["vendor"] = "OpenAI"
        sample_asr_model_data["id"] = "test-asr-openai"
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        # Mock the HTTP response (OpenAI-style transcript payload).
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"text": "Test transcript"}
        mock_response.raise_for_status = MagicMock()
        mock_client.get.return_value = mock_response
        mock_client.__enter__ = MagicMock(return_value=mock_client)
        mock_client.__exit__ = MagicMock(return_value=False)

        with patch('app.routers.asr.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/asr/{model_id}/test")
            assert response.status_code == 200

    @patch('httpx.Client')
    def test_test_asr_model_failure(self, mock_client_class, client, sample_asr_model_data):
        """Test testing an ASR model with failed connection"""
        # Create model first
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        # Mock HTTP error: the endpoint should report failure in the body,
        # not propagate the exception (still HTTP 200).
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = 401
        mock_response.text = "Unauthorized"
        mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized"))
        mock_client.get.return_value = mock_response
        mock_client.__enter__ = MagicMock(return_value=mock_client)
        mock_client.__exit__ = MagicMock(return_value=False)

        with patch('app.routers.asr.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/asr/{model_id}/test")
            assert response.status_code == 200
            data = response.json()
            assert data["success"] == False

    def test_different_asr_languages(self, client):
        """Test creating ASR models with different languages"""
        for lang in ["zh", "en", "Multi-lingual"]:
            data = {
                "id": f"asr-lang-{lang}",
                "name": f"ASR {lang}",
                "vendor": "SiliconFlow",
                "language": lang,
                "base_url": "https://api.siliconflow.cn/v1",
                "api_key": "test-key"
            }
            response = client.post("/api/asr", json=data)
            assert response.status_code == 200
            assert response.json()["language"] == lang

    def test_different_asr_vendors(self, client):
        """Test creating ASR models with different vendors"""
        vendors = ["SiliconFlow", "OpenAI", "Azure"]
        for vendor in vendors:
            data = {
                "id": f"asr-vendor-{vendor.lower()}",
                "name": f"ASR {vendor}",
                "vendor": vendor,
                "language": "zh",
                "base_url": f"https://api.{vendor.lower()}.com/v1",
                "api_key": "test-key"
            }
            response = client.post("/api/asr", json=data)
            assert response.status_code == 200
            assert response.json()["vendor"] == vendor
|
||||
246
api/tests/test_llm.py
Normal file
246
api/tests/test_llm.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""Tests for LLM Model API endpoints"""
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestLLMModelAPI:
|
||||
"""Test cases for LLM Model endpoints"""
|
||||
|
||||
def test_get_llm_models_empty(self, client):
|
||||
"""Test getting LLM models when database is empty"""
|
||||
response = client.get("/api/llm")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total" in data
|
||||
assert "list" in data
|
||||
assert data["total"] == 0
|
||||
|
||||
def test_create_llm_model(self, client, sample_llm_model_data):
|
||||
"""Test creating a new LLM model"""
|
||||
response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == sample_llm_model_data["name"]
|
||||
assert data["vendor"] == sample_llm_model_data["vendor"]
|
||||
assert data["type"] == sample_llm_model_data["type"]
|
||||
assert data["base_url"] == sample_llm_model_data["base_url"]
|
||||
assert "id" in data
|
||||
|
||||
def test_create_llm_model_minimal(self, client):
|
||||
"""Test creating an LLM model with minimal required data"""
|
||||
data = {
|
||||
"name": "Minimal LLM",
|
||||
"vendor": "Test",
|
||||
"type": "text",
|
||||
"base_url": "https://api.test.com",
|
||||
"api_key": "test-key"
|
||||
}
|
||||
response = client.post("/api/llm", json=data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Minimal LLM"
|
||||
|
||||
def test_get_llm_model_by_id(self, client, sample_llm_model_data):
|
||||
"""Test getting a specific LLM model by ID"""
|
||||
# Create first
|
||||
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Get by ID
|
||||
response = client.get(f"/api/llm/{model_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == model_id
|
||||
assert data["name"] == sample_llm_model_data["name"]
|
||||
|
||||
def test_get_llm_model_not_found(self, client):
|
||||
"""Test getting a non-existent LLM model"""
|
||||
response = client.get("/api/llm/non-existent-id")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_update_llm_model(self, client, sample_llm_model_data):
|
||||
"""Test updating an LLM model"""
|
||||
# Create first
|
||||
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Update
|
||||
update_data = {
|
||||
"name": "Updated LLM Model",
|
||||
"temperature": 0.5,
|
||||
"context_length": 8192
|
||||
}
|
||||
response = client.put(f"/api/llm/{model_id}", json=update_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Updated LLM Model"
|
||||
assert data["temperature"] == 0.5
|
||||
assert data["context_length"] == 8192
|
||||
|
||||
def test_delete_llm_model(self, client, sample_llm_model_data):
|
||||
"""Test deleting an LLM model"""
|
||||
# Create first
|
||||
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Delete
|
||||
response = client.delete(f"/api/llm/{model_id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify deleted
|
||||
get_response = client.get(f"/api/llm/{model_id}")
|
||||
assert get_response.status_code == 404
|
||||
|
||||
def test_list_llm_models_with_pagination(self, client, sample_llm_model_data):
|
||||
"""Test listing LLM models with pagination"""
|
||||
# Create multiple models
|
||||
for i in range(3):
|
||||
data = sample_llm_model_data.copy()
|
||||
data["id"] = f"test-llm-{i}"
|
||||
data["name"] = f"LLM Model {i}"
|
||||
client.post("/api/llm", json=data)
|
||||
|
||||
# Test pagination
|
||||
response = client.get("/api/llm?page=1&limit=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 3
|
||||
assert len(data["list"]) == 2
|
||||
|
||||
def test_filter_llm_models_by_type(self, client, sample_llm_model_data):
|
||||
"""Test filtering LLM models by type"""
|
||||
# Create models with different types
|
||||
for i, model_type in enumerate(["text", "embedding", "rerank"]):
|
||||
data = sample_llm_model_data.copy()
|
||||
data["id"] = f"test-llm-{model_type}"
|
||||
data["name"] = f"LLM {model_type}"
|
||||
data["type"] = model_type
|
||||
client.post("/api/llm", json=data)
|
||||
|
||||
# Filter by type
|
||||
response = client.get("/api/llm?model_type=text")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
for model in data["list"]:
|
||||
assert model["type"] == "text"
|
||||
|
||||
def test_filter_llm_models_by_enabled(self, client, sample_llm_model_data):
|
||||
"""Test filtering LLM models by enabled status"""
|
||||
# Create enabled and disabled models
|
||||
data = sample_llm_model_data.copy()
|
||||
data["id"] = "test-llm-enabled"
|
||||
data["name"] = "Enabled LLM"
|
||||
data["enabled"] = True
|
||||
client.post("/api/llm", json=data)
|
||||
|
||||
data["id"] = "test-llm-disabled"
|
||||
data["name"] = "Disabled LLM"
|
||||
data["enabled"] = False
|
||||
client.post("/api/llm", json=data)
|
||||
|
||||
# Filter by enabled
|
||||
response = client.get("/api/llm?enabled=true")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for model in data["list"]:
|
||||
assert model["enabled"] == True
|
||||
|
||||
def test_create_llm_model_with_all_fields(self, client):
|
||||
"""Test creating an LLM model with all fields"""
|
||||
data = {
|
||||
"id": "full-llm",
|
||||
"name": "Full LLM Model",
|
||||
"vendor": "OpenAI",
|
||||
"type": "text",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-test",
|
||||
"model_name": "gpt-4",
|
||||
"temperature": 0.8,
|
||||
"context_length": 16384,
|
||||
"enabled": True
|
||||
}
|
||||
response = client.post("/api/llm", json=data)
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
assert result["name"] == "Full LLM Model"
|
||||
assert result["temperature"] == 0.8
|
||||
assert result["context_length"] == 16384
|
||||
|
||||
@patch('httpx.Client')
|
||||
def test_test_llm_model_success(self, mock_client_class, client, sample_llm_model_data):
|
||||
"""Test testing an LLM model with successful connection"""
|
||||
# Create model first
|
||||
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Mock the HTTP response
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "OK"}}]
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch('app.routers.llm.httpx.Client', return_value=mock_client):
|
||||
response = client.post(f"/api/llm/{model_id}/test")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] == True
|
||||
|
||||
@patch('httpx.Client')
def test_test_llm_model_failure(self, mock_client_class, client, sample_llm_model_data):
    """Test testing an LLM model with failed connection.

    Stubs the upstream call to raise on ``raise_for_status`` (401) and
    asserts the test endpoint still returns HTTP 200 with
    ``success == False`` rather than propagating the error.
    """
    # Create model first
    create_response = client.post("/api/llm", json=sample_llm_model_data)
    model_id = create_response.json()["id"]

    # Mock HTTP error: upstream rejects the API key.
    mock_client = MagicMock()
    mock_response = MagicMock()
    mock_response.status_code = 401
    mock_response.text = "Unauthorized"
    mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized"))
    mock_client.post.return_value = mock_response
    # Support ``with httpx.Client(...)`` in the router under test.
    mock_client.__enter__ = MagicMock(return_value=mock_client)
    mock_client.__exit__ = MagicMock(return_value=False)

    with patch('app.routers.llm.httpx.Client', return_value=mock_client):
        response = client.post(f"/api/llm/{model_id}/test")
        assert response.status_code == 200
        data = response.json()
        # ``is False`` (not ``== False``) pins the JSON boolean exactly (E712).
        assert data["success"] is False
|
||||
def test_different_llm_vendors(self, client):
    """Every supported vendor name can be registered as an LLM model."""
    for vendor in ("OpenAI", "SiliconFlow", "ZhipuAI", "Anthropic"):
        lower = vendor.lower()
        payload = {
            "id": f"test-{lower}",
            "name": f"Test {vendor}",
            "vendor": vendor,
            "type": "text",
            "base_url": f"https://api.{lower}.com/v1",
            "api_key": "test-key",
        }
        resp = client.post("/api/llm", json=payload)
        assert resp.status_code == 200
        assert resp.json()["vendor"] == vendor
|
||||
def test_embedding_llm_model(self, client):
    """An LLM model of type ``embedding`` is accepted and stored as such."""
    resp = client.post(
        "/api/llm",
        json={
            "id": "embedding-test",
            "name": "Embedding Model",
            "vendor": "OpenAI",
            "type": "embedding",
            "base_url": "https://api.openai.com/v1",
            "api_key": "test-key",
            "model_name": "text-embedding-3-small",
        },
    )
    assert resp.status_code == 200
    assert resp.json()["type"] == "embedding"
||||
267
api/tests/test_tools.py
Normal file
267
api/tests/test_tools.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Tests for Tools & Autotest API endpoints"""
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestToolsAPI:
    """Test cases for the tool-registry endpoints under /api/tools."""

    def test_list_available_tools(self, client):
        """The registry lists the built-in tools by id."""
        resp = client.get("/api/tools/list")
        assert resp.status_code == 200
        body = resp.json()
        assert "tools" in body
        # All built-in tools must be registered.
        for builtin in ("search", "calculator", "weather"):
            assert builtin in body["tools"]

    def test_get_tool_detail(self, client):
        """A single tool's detail includes its display name and schema."""
        resp = client.get("/api/tools/list/search")
        assert resp.status_code == 200
        detail = resp.json()
        assert detail["name"] == "网络搜索"
        assert "parameters" in detail

    def test_get_tool_detail_not_found(self, client):
        """An unknown tool id yields HTTP 404."""
        resp = client.get("/api/tools/list/non-existent-tool")
        assert resp.status_code == 404

    def test_health_check(self, client):
        """The health endpoint reports status plus timestamp and tool info."""
        resp = client.get("/api/tools/health")
        assert resp.status_code == 200
        report = resp.json()
        assert report["status"] == "healthy"
        for key in ("timestamp", "tools"):
            assert key in report
|
||||
class TestAutotestAPI:
    """Test cases for Autotest endpoints under /api/tools.

    Covers the full autotest run (with/without model ids), per-model
    autotests for LLM and ASR, the ad-hoc test-message endpoint, and the
    structure of autotest results and tool definitions.
    """

    def test_autotest_no_models(self, client):
        """Test autotest without specifying model IDs."""
        response = client.post("/api/tools/autotest")
        assert response.status_code == 200
        data = response.json()
        assert "id" in data
        assert "tests" in data
        assert "summary" in data
        # Should have test failures since no models provided
        assert data["summary"]["total"] > 0

    def test_autotest_with_llm_model(self, client, sample_llm_model_data):
        """Test autotest with an LLM model."""
        # Create an LLM model first
        create_response = client.post("/api/llm", json=sample_llm_model_data)
        model_id = create_response.json()["id"]

        # Run autotest (ASR checks disabled so only the LLM path runs)
        response = client.post(f"/api/tools/autotest?llm_model_id={model_id}&test_asr=false")
        assert response.status_code == 200
        data = response.json()
        assert "tests" in data
        assert "summary" in data

    def test_autotest_with_asr_model(self, client, sample_asr_model_data):
        """Test autotest with an ASR model."""
        # Create an ASR model first
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        # Run autotest (LLM checks disabled so only the ASR path runs)
        response = client.post(f"/api/tools/autotest?asr_model_id={model_id}&test_llm=false")
        assert response.status_code == 200
        data = response.json()
        assert "tests" in data
        assert "summary" in data

    def test_autotest_with_both_models(self, client, sample_llm_model_data, sample_asr_model_data):
        """Test autotest with both LLM and ASR models."""
        # Create models
        llm_response = client.post("/api/llm", json=sample_llm_model_data)
        llm_id = llm_response.json()["id"]

        asr_response = client.post("/api/asr", json=sample_asr_model_data)
        asr_id = asr_response.json()["id"]

        # Run autotest
        response = client.post(
            f"/api/tools/autotest?llm_model_id={llm_id}&asr_model_id={asr_id}"
        )
        assert response.status_code == 200
        data = response.json()
        assert "tests" in data
        assert "summary" in data

    @patch('httpx.Client')
    def test_autotest_llm_model_success(self, mock_client_class, client, sample_llm_model_data):
        """Test autotest for a specific LLM model with successful connection."""
        # Create an LLM model first
        create_response = client.post("/api/llm", json=sample_llm_model_data)
        model_id = create_response.json()["id"]

        # Mock the HTTP response for successful connection
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "OK"}}]
        }
        mock_response.raise_for_status = MagicMock()
        # The streaming check iterates raw bytes; give it two chunks.
        mock_response.iter_bytes = MagicMock(return_value=[b'chunk1', b'chunk2'])
        mock_client.post.return_value = mock_response
        # Support the ``with httpx.Client(...)`` context-manager usage.
        mock_client.__enter__ = MagicMock(return_value=mock_client)
        mock_client.__exit__ = MagicMock(return_value=False)

        # Patch at the tools router's import site so the stub is used.
        with patch('app.routers.tools.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/tools/autotest/llm/{model_id}")
            assert response.status_code == 200
            data = response.json()
            assert "tests" in data
            assert "summary" in data

    @patch('httpx.Client')
    def test_autotest_asr_model_success(self, mock_client_class, client, sample_asr_model_data):
        """Test autotest for a specific ASR model with successful connection."""
        # Create an ASR model first
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        # Mock the HTTP response for successful connection
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.raise_for_status = MagicMock()
        mock_client.get.return_value = mock_response
        # Support the ``with httpx.Client(...)`` context-manager usage.
        mock_client.__enter__ = MagicMock(return_value=mock_client)
        mock_client.__exit__ = MagicMock(return_value=False)

        with patch('app.routers.tools.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/tools/autotest/asr/{model_id}")
            assert response.status_code == 200
            data = response.json()
            assert "tests" in data
            assert "summary" in data

    def test_autotest_llm_model_not_found(self, client):
        """Test autotest for a non-existent LLM model."""
        response = client.post("/api/tools/autotest/llm/non-existent-id")
        # Endpoint reports failures in the result body, not via HTTP status.
        assert response.status_code == 200
        data = response.json()
        # Should have a failure test
        assert any(not t["passed"] for t in data["tests"])

    def test_autotest_asr_model_not_found(self, client):
        """Test autotest for a non-existent ASR model."""
        response = client.post("/api/tools/autotest/asr/non-existent-id")
        # Endpoint reports failures in the result body, not via HTTP status.
        assert response.status_code == 200
        data = response.json()
        # Should have a failure test
        assert any(not t["passed"] for t in data["tests"])

    @patch('httpx.Client')
    def test_test_message_success(self, mock_client_class, client, sample_llm_model_data):
        """Test sending a test message to an LLM model."""
        # Create an LLM model first
        create_response = client.post("/api/llm", json=sample_llm_model_data)
        model_id = create_response.json()["id"]

        # Mock the HTTP response
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "Hello! This is a test reply."}}],
            "usage": {"total_tokens": 10}
        }
        mock_response.raise_for_status = MagicMock()
        mock_client.post.return_value = mock_response
        mock_client.__enter__ = MagicMock(return_value=mock_client)
        mock_client.__exit__ = MagicMock(return_value=False)

        with patch('app.routers.tools.httpx.Client', return_value=mock_client):
            response = client.post(
                f"/api/tools/test-message?llm_model_id={model_id}",
                json={"message": "Hello!"}
            )
            assert response.status_code == 200
            data = response.json()
            # ``is True`` (not ``== True``) pins the JSON boolean exactly (E712).
            assert data["success"] is True
            assert "reply" in data

    def test_test_message_model_not_found(self, client):
        """Test sending a test message to a non-existent model."""
        response = client.post(
            "/api/tools/test-message?llm_model_id=non-existent",
            json={"message": "Hello!"}
        )
        assert response.status_code == 404

    def test_autotest_result_structure(self, client):
        """Test that autotest results have the correct structure."""
        response = client.post("/api/tools/autotest")
        assert response.status_code == 200
        data = response.json()

        # Check required fields
        assert "id" in data
        assert "started_at" in data
        assert "duration_ms" in data
        assert "tests" in data
        assert "summary" in data

        # Check summary structure
        assert "passed" in data["summary"]
        assert "failed" in data["summary"]
        assert "total" in data["summary"]

        # Check test structure
        if data["tests"]:
            test = data["tests"][0]
            assert "name" in test
            assert "passed" in test
            assert "message" in test
            assert "duration_ms" in test

    def test_tools_have_required_fields(self, client):
        """Test that all tools have required fields."""
        response = client.get("/api/tools/list")
        assert response.status_code == 200
        data = response.json()

        for tool_id, tool in data["tools"].items():
            assert "name" in tool
            assert "description" in tool
            assert "parameters" in tool

            # Check parameters structure (JSON-schema-like)
            params = tool["parameters"]
            assert "type" in params
            assert "properties" in params

    def test_calculator_tool_parameters(self, client):
        """Test calculator tool has correct parameters."""
        response = client.get("/api/tools/list/calculator")
        assert response.status_code == 200
        data = response.json()

        assert data["name"] == "计算器"
        assert "expression" in data["parameters"]["properties"]
        assert "required" in data["parameters"]
        assert "expression" in data["parameters"]["required"]

    def test_translate_tool_parameters(self, client):
        """Test translate tool has correct parameters."""
        response = client.get("/api/tools/list/translate")
        assert response.status_code == 200
        data = response.json()

        assert data["name"] == "翻译"
        assert "text" in data["parameters"]["properties"]
        assert "target_lang" in data["parameters"]["properties"]
||||
Reference in New Issue
Block a user