Update backend api
This commit is contained in:
@@ -3,9 +3,15 @@ from fastapi import APIRouter
|
||||
from . import assistants
|
||||
from . import history
|
||||
from . import knowledge
|
||||
from . import llm
|
||||
from . import asr
|
||||
from . import tools
|
||||
|
||||
router = APIRouter()

# Mount every feature sub-router on the top-level API router.
# Order matches the import order above and is preserved from the original.
for _sub in (assistants, history, knowledge, llm, asr, tools):
    router.include_router(_sub.router)
|
||||
|
||||
268
api/app/routers/asr.py
Normal file
268
api/app/routers/asr.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
import httpx
|
||||
import time
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import ASRModel
|
||||
from ..schemas import (
|
||||
ASRModelCreate, ASRModelUpdate, ASRModelOut,
|
||||
ASRTestRequest, ASRTestResponse, ListResponse
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/asr", tags=["ASR Models"])
|
||||
|
||||
|
||||
# ============ ASR Models CRUD ============
|
||||
@router.get("", response_model=ListResponse)
|
||||
def list_asr_models(
|
||||
language: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取ASR模型列表"""
|
||||
query = db.query(ASRModel)
|
||||
|
||||
if language:
|
||||
query = query.filter(ASRModel.language == language)
|
||||
if enabled is not None:
|
||||
query = query.filter(ASRModel.enabled == enabled)
|
||||
|
||||
total = query.count()
|
||||
models = query.order_by(ASRModel.created_at.desc()) \
|
||||
.offset((page-1)*limit).limit(limit).all()
|
||||
|
||||
return {"total": total, "page": page, "limit": limit, "list": models}
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=ASRModelOut)
|
||||
def get_asr_model(id: str, db: Session = Depends(get_db)):
|
||||
"""获取单个ASR模型详情"""
|
||||
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||
return model
|
||||
|
||||
|
||||
@router.post("", response_model=ASRModelOut)
|
||||
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
|
||||
"""创建ASR模型"""
|
||||
asr_model = ASRModel(
|
||||
id=data.id or str(uuid.uuid4())[:8],
|
||||
user_id=1, # 默认用户
|
||||
name=data.name,
|
||||
vendor=data.vendor,
|
||||
language=data.language,
|
||||
base_url=data.base_url,
|
||||
api_key=data.api_key,
|
||||
model_name=data.model_name,
|
||||
hotwords=data.hotwords,
|
||||
enable_punctuation=data.enable_punctuation,
|
||||
enable_normalization=data.enable_normalization,
|
||||
enabled=data.enabled,
|
||||
)
|
||||
db.add(asr_model)
|
||||
db.commit()
|
||||
db.refresh(asr_model)
|
||||
return asr_model
|
||||
|
||||
|
||||
@router.put("/{id}", response_model=ASRModelOut)
|
||||
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
|
||||
"""更新ASR模型"""
|
||||
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(model, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
|
||||
@router.delete("/{id}")
|
||||
def delete_asr_model(id: str, db: Session = Depends(get_db)):
|
||||
"""删除ASR模型"""
|
||||
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||
db.delete(model)
|
||||
db.commit()
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/{id}/test", response_model=ASRTestResponse)
|
||||
def test_asr_model(
|
||||
id: str,
|
||||
request: Optional[ASRTestRequest] = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""测试ASR模型"""
|
||||
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 根据不同的厂商构造不同的请求
|
||||
if model.vendor.lower() in ["siliconflow", "paraformer"]:
|
||||
# SiliconFlow/Paraformer 格式
|
||||
payload = {
|
||||
"model": model.model_name or "paraformer-v2",
|
||||
"input": {},
|
||||
"parameters": {
|
||||
"hotwords": " ".join(model.hotwords) if model.hotwords else "",
|
||||
"enable_punctuation": model.enable_punctuation,
|
||||
"enable_normalization": model.enable_normalization,
|
||||
}
|
||||
}
|
||||
|
||||
# 如果有音频数据
|
||||
if request and request.audio_data:
|
||||
payload["input"]["file_urls"] = []
|
||||
elif request and request.audio_url:
|
||||
payload["input"]["url"] = request.audio_url
|
||||
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/asr",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
elif model.vendor.lower() == "openai":
|
||||
# OpenAI Whisper 格式
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
# 准备文件
|
||||
files = {}
|
||||
if request and request.audio_data:
|
||||
audio_bytes = base64.b64decode(request.audio_data)
|
||||
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
|
||||
data = {"model": model.model_name or "whisper-1"}
|
||||
elif request and request.audio_url:
|
||||
files = {"file": ("audio.wav", httpx.get(request.audio_url).content, "audio/wav")}
|
||||
data = {"model": model.model_name or "whisper-1"}
|
||||
else:
|
||||
return ASRTestResponse(
|
||||
success=False,
|
||||
error="No audio data or URL provided"
|
||||
)
|
||||
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/audio/transcriptions",
|
||||
files=files,
|
||||
data=data,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
result = {"results": [{"transcript": result.get("text", "")}]}
|
||||
|
||||
else:
|
||||
# 通用格式(可根据需要扩展)
|
||||
return ASRTestResponse(
|
||||
success=False,
|
||||
message=f"Unsupported vendor: {model.vendor}"
|
||||
)
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 解析结果
|
||||
if result_data := result.get("results", [{}])[0]:
|
||||
transcript = result_data.get("transcript", "")
|
||||
return ASRTestResponse(
|
||||
success=True,
|
||||
transcript=transcript,
|
||||
language=result_data.get("language", model.language),
|
||||
confidence=result_data.get("confidence"),
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
return ASRTestResponse(
|
||||
success=False,
|
||||
message="No transcript in response",
|
||||
latency_ms=latency_ms
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
return ASRTestResponse(
|
||||
success=False,
|
||||
error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
|
||||
)
|
||||
except Exception as e:
|
||||
return ASRTestResponse(
|
||||
success=False,
|
||||
error=str(e)[:200]
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{id}/transcribe")
|
||||
def transcribe_audio(
|
||||
id: str,
|
||||
audio_url: Optional[str] = None,
|
||||
audio_data: Optional[str] = None,
|
||||
hotwords: Optional[List[str]] = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""转写音频"""
|
||||
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"model": model.model_name or "paraformer-v2",
|
||||
"input": {},
|
||||
"parameters": {
|
||||
"hotwords": " ".join(hotwords or model.hotwords or []),
|
||||
"enable_punctuation": model.enable_punctuation,
|
||||
"enable_normalization": model.enable_normalization,
|
||||
}
|
||||
}
|
||||
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
if audio_url:
|
||||
payload["input"]["url"] = audio_url
|
||||
elif audio_data:
|
||||
payload["input"]["file_urls"] = []
|
||||
|
||||
with httpx.Client(timeout=120.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/asr",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
|
||||
if result_data := result.get("results", [{}])[0]:
|
||||
return {
|
||||
"success": True,
|
||||
"transcript": result_data.get("transcript", ""),
|
||||
"language": result_data.get("language", model.language),
|
||||
"confidence": result_data.get("confidence"),
|
||||
}
|
||||
|
||||
return {"success": False, "error": "No transcript in response"}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
@@ -8,7 +8,7 @@ from ..db import get_db
|
||||
from ..models import Assistant, Voice, Workflow
|
||||
from ..schemas import (
|
||||
AssistantCreate, AssistantUpdate, AssistantOut,
|
||||
VoiceOut,
|
||||
VoiceCreate, VoiceUpdate, VoiceOut,
|
||||
WorkflowCreate, WorkflowUpdate, WorkflowOut
|
||||
)
|
||||
|
||||
@@ -16,11 +16,88 @@ router = APIRouter()
|
||||
|
||||
|
||||
# ============ Voices ============
|
||||
@router.get("/voices", response_model=List[VoiceOut])
|
||||
def list_voices(db: Session = Depends(get_db)):
|
||||
@router.get("/voices")
|
||||
def list_voices(
|
||||
vendor: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
gender: Optional[str] = None,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取声音库列表"""
|
||||
voices = db.query(Voice).all()
|
||||
return voices
|
||||
query = db.query(Voice)
|
||||
if vendor:
|
||||
query = query.filter(Voice.vendor == vendor)
|
||||
if language:
|
||||
query = query.filter(Voice.language == language)
|
||||
if gender:
|
||||
query = query.filter(Voice.gender == gender)
|
||||
|
||||
total = query.count()
|
||||
voices = query.order_by(Voice.created_at.desc()) \
|
||||
.offset((page-1)*limit).limit(limit).all()
|
||||
return {"total": total, "page": page, "limit": limit, "list": voices}
|
||||
|
||||
|
||||
@router.post("/voices", response_model=VoiceOut)
|
||||
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
|
||||
"""创建声音"""
|
||||
voice = Voice(
|
||||
id=data.id or str(uuid.uuid4())[:8],
|
||||
user_id=1,
|
||||
name=data.name,
|
||||
vendor=data.vendor,
|
||||
gender=data.gender,
|
||||
language=data.language,
|
||||
description=data.description,
|
||||
model=data.model,
|
||||
voice_key=data.voice_key,
|
||||
speed=data.speed,
|
||||
gain=data.gain,
|
||||
pitch=data.pitch,
|
||||
enabled=data.enabled,
|
||||
)
|
||||
db.add(voice)
|
||||
db.commit()
|
||||
db.refresh(voice)
|
||||
return voice
|
||||
|
||||
|
||||
@router.get("/voices/{id}", response_model=VoiceOut)
|
||||
def get_voice(id: str, db: Session = Depends(get_db)):
|
||||
"""获取单个声音详情"""
|
||||
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||
if not voice:
|
||||
raise HTTPException(status_code=404, detail="Voice not found")
|
||||
return voice
|
||||
|
||||
|
||||
@router.put("/voices/{id}", response_model=VoiceOut)
|
||||
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
|
||||
"""更新声音"""
|
||||
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||
if not voice:
|
||||
raise HTTPException(status_code=404, detail="Voice not found")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(voice, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(voice)
|
||||
return voice
|
||||
|
||||
|
||||
@router.delete("/voices/{id}")
|
||||
def delete_voice(id: str, db: Session = Depends(get_db)):
|
||||
"""删除声音"""
|
||||
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||
if not voice:
|
||||
raise HTTPException(status_code=404, detail="Voice not found")
|
||||
db.delete(voice)
|
||||
db.commit()
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
# ============ Assistants ============
|
||||
@@ -79,11 +156,11 @@ def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_d
|
||||
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||
if not assistant:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(assistant, field, value)
|
||||
|
||||
|
||||
assistant.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(assistant)
|
||||
@@ -103,10 +180,17 @@ def delete_assistant(id: str, db: Session = Depends(get_db)):
|
||||
|
||||
# ============ Workflows ============
|
||||
@router.get("/workflows", response_model=List[WorkflowOut])
|
||||
def list_workflows(db: Session = Depends(get_db)):
|
||||
def list_workflows(
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取工作流列表"""
|
||||
workflows = db.query(Workflow).all()
|
||||
return workflows
|
||||
query = db.query(Workflow)
|
||||
total = query.count()
|
||||
workflows = query.order_by(Workflow.created_at.desc()) \
|
||||
.offset((page-1)*limit).limit(limit).all()
|
||||
return {"total": total, "page": page, "limit": limit, "list": workflows}
|
||||
|
||||
|
||||
@router.post("/workflows", response_model=WorkflowOut)
|
||||
@@ -129,17 +213,26 @@ def create_workflow(data: WorkflowCreate, db: Session = Depends(get_db)):
|
||||
return workflow
|
||||
|
||||
|
||||
@router.get("/workflows/{id}", response_model=WorkflowOut)
|
||||
def get_workflow(id: str, db: Session = Depends(get_db)):
|
||||
"""获取单个工作流"""
|
||||
workflow = db.query(Workflow).filter(Workflow.id == id).first()
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
return workflow
|
||||
|
||||
|
||||
@router.put("/workflows/{id}", response_model=WorkflowOut)
|
||||
def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)):
|
||||
"""更新工作流"""
|
||||
workflow = db.query(Workflow).filter(Workflow.id == id).first()
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(workflow, field, value)
|
||||
|
||||
|
||||
workflow.updated_at = datetime.utcnow().isoformat()
|
||||
db.commit()
|
||||
db.refresh(workflow)
|
||||
|
||||
206
api/app/routers/llm.py
Normal file
206
api/app/routers/llm.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
import httpx
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import LLMModel
|
||||
from ..schemas import (
|
||||
LLMModelCreate, LLMModelUpdate, LLMModelOut,
|
||||
LLMModelTestResponse, ListResponse
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/llm", tags=["LLM Models"])
|
||||
|
||||
|
||||
# ============ LLM Models CRUD ============
|
||||
@router.get("", response_model=ListResponse)
|
||||
def list_llm_models(
|
||||
model_type: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取LLM模型列表"""
|
||||
query = db.query(LLMModel)
|
||||
|
||||
if model_type:
|
||||
query = query.filter(LLMModel.type == model_type)
|
||||
if enabled is not None:
|
||||
query = query.filter(LLMModel.enabled == enabled)
|
||||
|
||||
total = query.count()
|
||||
models = query.order_by(LLMModel.created_at.desc()) \
|
||||
.offset((page-1)*limit).limit(limit).all()
|
||||
|
||||
return {"total": total, "page": page, "limit": limit, "list": models}
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=LLMModelOut)
|
||||
def get_llm_model(id: str, db: Session = Depends(get_db)):
|
||||
"""获取单个LLM模型详情"""
|
||||
model = db.query(LLMModel).filter(LLMModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="LLM Model not found")
|
||||
return model
|
||||
|
||||
|
||||
@router.post("", response_model=LLMModelOut)
|
||||
def create_llm_model(data: LLMModelCreate, db: Session = Depends(get_db)):
|
||||
"""创建LLM模型"""
|
||||
llm_model = LLMModel(
|
||||
id=data.id or str(uuid.uuid4())[:8],
|
||||
user_id=1, # 默认用户
|
||||
name=data.name,
|
||||
vendor=data.vendor,
|
||||
type=data.type.value if hasattr(data.type, 'value') else data.type,
|
||||
base_url=data.base_url,
|
||||
api_key=data.api_key,
|
||||
model_name=data.model_name,
|
||||
temperature=data.temperature,
|
||||
context_length=data.context_length,
|
||||
enabled=data.enabled,
|
||||
)
|
||||
db.add(llm_model)
|
||||
db.commit()
|
||||
db.refresh(llm_model)
|
||||
return llm_model
|
||||
|
||||
|
||||
@router.put("/{id}", response_model=LLMModelOut)
|
||||
def update_llm_model(id: str, data: LLMModelUpdate, db: Session = Depends(get_db)):
|
||||
"""更新LLM模型"""
|
||||
model = db.query(LLMModel).filter(LLMModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="LLM Model not found")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(model, field, value)
|
||||
|
||||
model.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
|
||||
@router.delete("/{id}")
|
||||
def delete_llm_model(id: str, db: Session = Depends(get_db)):
|
||||
"""删除LLM模型"""
|
||||
model = db.query(LLMModel).filter(LLMModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="LLM Model not found")
|
||||
db.delete(model)
|
||||
db.commit()
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/{id}/test", response_model=LLMModelTestResponse)
|
||||
def test_llm_model(id: str, db: Session = Depends(get_db)):
|
||||
"""测试LLM模型连接"""
|
||||
model = db.query(LLMModel).filter(LLMModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="LLM Model not found")
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 构造测试请求
|
||||
test_messages = [{"role": "user", "content": "Hello, please reply with 'OK'."}]
|
||||
|
||||
payload = {
|
||||
"model": model.model_name or "gpt-3.5-turbo",
|
||||
"messages": test_messages,
|
||||
"max_tokens": 10,
|
||||
"temperature": 0.1,
|
||||
}
|
||||
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/chat/completions",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
result = response.json()
|
||||
|
||||
if result.get("choices"):
|
||||
return LLMModelTestResponse(
|
||||
success=True,
|
||||
latency_ms=latency_ms,
|
||||
message="Connection successful"
|
||||
)
|
||||
else:
|
||||
return LLMModelTestResponse(
|
||||
success=False,
|
||||
latency_ms=latency_ms,
|
||||
message="Unexpected response format"
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
return LLMModelTestResponse(
|
||||
success=False,
|
||||
message=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMModelTestResponse(
|
||||
success=False,
|
||||
message=str(e)[:200]
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{id}/chat")
|
||||
def chat_with_llm(
|
||||
id: str,
|
||||
message: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""与LLM模型对话"""
|
||||
model = db.query(LLMModel).filter(LLMModel.id == id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="LLM Model not found")
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": message})
|
||||
|
||||
payload = {
|
||||
"model": model.model_name or "gpt-3.5-turbo",
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens or 1000,
|
||||
"temperature": temperature if temperature is not None else model.temperature or 0.7,
|
||||
}
|
||||
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/chat/completions",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
if choice := result.get("choices", [{}])[0]:
|
||||
return {
|
||||
"success": True,
|
||||
"reply": choice.get("message", {}).get("content", ""),
|
||||
"usage": result.get("usage", {})
|
||||
}
|
||||
return {"success": False, "reply": "", "error": "No response"}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
379
api/app/routers/tools.py
Normal file
379
api/app/routers/tools.py
Normal file
@@ -0,0 +1,379 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, Dict, Any
|
||||
import time
|
||||
import uuid
|
||||
import httpx
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import LLMModel, ASRModel
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["Tools & Autotest"])
|
||||
|
||||
|
||||
# ============ Available Tools ============
|
||||
TOOL_REGISTRY = {
|
||||
"search": {
|
||||
"name": "网络搜索",
|
||||
"description": "搜索互联网获取最新信息",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索关键词"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
"calculator": {
|
||||
"name": "计算器",
|
||||
"description": "执行数学计算",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
},
|
||||
"weather": {
|
||||
"name": "天气查询",
|
||||
"description": "查询指定城市的天气",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "城市名称"}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
},
|
||||
"translate": {
|
||||
"name": "翻译",
|
||||
"description": "翻译文本到指定语言",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "要翻译的文本"},
|
||||
"target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"}
|
||||
},
|
||||
"required": ["text", "target_lang"]
|
||||
}
|
||||
},
|
||||
"knowledge": {
|
||||
"name": "知识库查询",
|
||||
"description": "从知识库中检索相关信息",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "查询内容"},
|
||||
"kb_id": {"type": "string", "description": "知识库ID"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
"code_interpreter": {
|
||||
"name": "代码执行",
|
||||
"description": "安全地执行Python代码",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {"type": "string", "description": "要执行的Python代码"}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
def list_available_tools():
|
||||
"""获取可用的工具列表"""
|
||||
return {"tools": TOOL_REGISTRY}
|
||||
|
||||
|
||||
@router.get("/list/{tool_id}")
|
||||
def get_tool_detail(tool_id: str):
|
||||
"""获取工具详情"""
|
||||
if tool_id not in TOOL_REGISTRY:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
return TOOL_REGISTRY[tool_id]
|
||||
|
||||
|
||||
# ============ Autotest ============
|
||||
class AutotestResult:
|
||||
"""自动测试结果"""
|
||||
|
||||
def __init__(self):
|
||||
self.id = str(uuid.uuid4())[:8]
|
||||
self.started_at = time.time()
|
||||
self.tests = []
|
||||
self.summary = {"passed": 0, "failed": 0, "total": 0}
|
||||
|
||||
def add_test(self, name: str, passed: bool, message: str = "", duration_ms: int = 0):
|
||||
self.tests.append({
|
||||
"name": name,
|
||||
"passed": passed,
|
||||
"message": message,
|
||||
"duration_ms": duration_ms
|
||||
})
|
||||
if passed:
|
||||
self.summary["passed"] += 1
|
||||
else:
|
||||
self.summary["failed"] += 1
|
||||
self.summary["total"] += 1
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"started_at": self.started_at,
|
||||
"duration_ms": int((time.time() - self.started_at) * 1000),
|
||||
"tests": self.tests,
|
||||
"summary": self.summary
|
||||
}
|
||||
|
||||
|
||||
@router.post("/autotest")
|
||||
def run_autotest(
|
||||
llm_model_id: Optional[str] = None,
|
||||
asr_model_id: Optional[str] = None,
|
||||
test_llm: bool = True,
|
||||
test_asr: bool = True,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""运行自动测试"""
|
||||
result = AutotestResult()
|
||||
|
||||
# 测试 LLM 模型
|
||||
if test_llm and llm_model_id:
|
||||
_test_llm_model(db, llm_model_id, result)
|
||||
|
||||
# 测试 ASR 模型
|
||||
if test_asr and asr_model_id:
|
||||
_test_asr_model(db, asr_model_id, result)
|
||||
|
||||
# 测试 TTS 功能(需要时可添加)
|
||||
if test_llm and not llm_model_id:
|
||||
result.add_test(
|
||||
"LLM Model Check",
|
||||
False,
|
||||
"No LLM model ID provided"
|
||||
)
|
||||
|
||||
if test_asr and not asr_model_id:
|
||||
result.add_test(
|
||||
"ASR Model Check",
|
||||
False,
|
||||
"No ASR model ID provided"
|
||||
)
|
||||
|
||||
return result.to_dict()
|
||||
|
||||
|
||||
@router.post("/autotest/llm/{model_id}")
|
||||
def autotest_llm_model(model_id: str, db: Session = Depends(get_db)):
|
||||
"""测试单个LLM模型"""
|
||||
result = AutotestResult()
|
||||
_test_llm_model(db, model_id, result)
|
||||
return result.to_dict()
|
||||
|
||||
|
||||
@router.post("/autotest/asr/{model_id}")
|
||||
def autotest_asr_model(model_id: str, db: Session = Depends(get_db)):
|
||||
"""测试单个ASR模型"""
|
||||
result = AutotestResult()
|
||||
_test_asr_model(db, model_id, result)
|
||||
return result.to_dict()
|
||||
|
||||
|
||||
def _test_llm_model(db: Session, model_id: str, result: AutotestResult):
|
||||
"""内部方法:测试LLM模型"""
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 检查模型是否存在
|
||||
model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
if not model:
|
||||
result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
|
||||
return
|
||||
|
||||
result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)
|
||||
|
||||
# 2. 测试连接
|
||||
test_start = time.time()
|
||||
try:
|
||||
test_messages = [{"role": "user", "content": "Reply with 'OK'."}]
|
||||
payload = {
|
||||
"model": model.model_name or "gpt-3.5-turbo",
|
||||
"messages": test_messages,
|
||||
"max_tokens": 10,
|
||||
"temperature": 0.1,
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/chat/completions",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result_text = response.json()
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
|
||||
if result_text.get("choices"):
|
||||
result.add_test("API Connection", True, f"Latency: {latency_ms}ms", latency_ms)
|
||||
else:
|
||||
result.add_test("API Connection", False, "Empty response", latency_ms)
|
||||
|
||||
except Exception as e:
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
result.add_test("API Connection", False, str(e)[:200], latency_ms)
|
||||
|
||||
# 3. 检查模型配置
|
||||
if model.temperature is not None:
|
||||
result.add_test("Temperature Setting", True, f"temperature={model.temperature}")
|
||||
else:
|
||||
result.add_test("Temperature Setting", True, "Using default")
|
||||
|
||||
# 4. 测试流式响应(可选)
|
||||
if model.type == "text":
|
||||
test_start = time.time()
|
||||
try:
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
with client.stream(
|
||||
"POST",
|
||||
f"{model.base_url}/chat/completions",
|
||||
json={
|
||||
"model": model.model_name or "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Count from 1 to 3."}],
|
||||
"stream": True,
|
||||
},
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
chunk_count = 0
|
||||
for _ in response.iter_bytes():
|
||||
chunk_count += 1
|
||||
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
result.add_test("Streaming Support", True, f"Received {chunk_count} chunks", latency_ms)
|
||||
except Exception as e:
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
result.add_test("Streaming Support", False, str(e)[:200], latency_ms)
|
||||
|
||||
|
||||
def _test_asr_model(db: Session, model_id: str, result: AutotestResult):
|
||||
"""内部方法:测试ASR模型"""
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 检查模型是否存在
|
||||
model = db.query(ASRModel).filter(ASRModel.id == model_id).first()
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
if not model:
|
||||
result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
|
||||
return
|
||||
|
||||
result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)
|
||||
|
||||
# 2. 测试配置
|
||||
if model.hotwords:
|
||||
result.add_test("Hotwords Config", True, f"Hotwords: {len(model.hotwords)} words")
|
||||
else:
|
||||
result.add_test("Hotwords Config", True, "No hotwords configured")
|
||||
|
||||
# 3. 测试API可用性
|
||||
test_start = time.time()
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
if model.vendor.lower() in ["siliconflow", "paraformer"]:
|
||||
response = client.get(
|
||||
f"{model.base_url}/asr",
|
||||
headers=headers
|
||||
)
|
||||
elif model.vendor.lower() == "openai":
|
||||
response = client.get(
|
||||
f"{model.base_url}/audio/models",
|
||||
headers=headers
|
||||
)
|
||||
else:
|
||||
# 通用健康检查
|
||||
response = client.get(
|
||||
f"{model.base_url}/health",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
|
||||
if response.status_code in [200, 405]: # 405 = method not allowed but endpoint exists
|
||||
result.add_test("API Availability", True, f"Status: {response.status_code}", latency_ms)
|
||||
else:
|
||||
result.add_test("API Availability", False, f"Status: {response.status_code}", latency_ms)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
result.add_test("API Availability", False, "Connection timeout", latency_ms)
|
||||
except Exception as e:
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
result.add_test("API Availability", False, str(e)[:200], latency_ms)
|
||||
|
||||
# 4. 检查语言配置
|
||||
if model.language in ["zh", "en", "Multi-lingual"]:
|
||||
result.add_test("Language Config", True, f"Language: {model.language}")
|
||||
else:
|
||||
result.add_test("Language Config", False, f"Unknown language: {model.language}")
|
||||
|
||||
|
||||
# ============ Quick Health Check ============
|
||||
@router.get("/health")
|
||||
def health_check():
|
||||
"""快速健康检查"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"timestamp": time.time(),
|
||||
"tools": list(TOOL_REGISTRY.keys())
|
||||
}
|
||||
|
||||
|
||||
@router.post("/test-message")
|
||||
def send_test_message(
|
||||
llm_model_id: str,
|
||||
message: str = "Hello, this is a test message.",
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""发送测试消息"""
|
||||
model = db.query(LLMModel).filter(LLMModel.id == llm_model_id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="LLM Model not found")
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"model": model.model_name or "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": message}],
|
||||
"max_tokens": 500,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/chat/completions",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
reply = result.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"reply": reply,
|
||||
"usage": result.get("usage", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
Reference in New Issue
Block a user