Update backend api
This commit is contained in:
379
api/app/routers/tools.py
Normal file
379
api/app/routers/tools.py
Normal file
@@ -0,0 +1,379 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, Dict, Any
|
||||
import time
|
||||
import uuid
|
||||
import httpx
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import LLMModel, ASRModel
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["Tools & Autotest"])
|
||||
|
||||
|
||||
# ============ Available Tools ============
|
||||
TOOL_REGISTRY = {
|
||||
"search": {
|
||||
"name": "网络搜索",
|
||||
"description": "搜索互联网获取最新信息",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索关键词"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
"calculator": {
|
||||
"name": "计算器",
|
||||
"description": "执行数学计算",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
},
|
||||
"weather": {
|
||||
"name": "天气查询",
|
||||
"description": "查询指定城市的天气",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "城市名称"}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
},
|
||||
"translate": {
|
||||
"name": "翻译",
|
||||
"description": "翻译文本到指定语言",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "要翻译的文本"},
|
||||
"target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"}
|
||||
},
|
||||
"required": ["text", "target_lang"]
|
||||
}
|
||||
},
|
||||
"knowledge": {
|
||||
"name": "知识库查询",
|
||||
"description": "从知识库中检索相关信息",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "查询内容"},
|
||||
"kb_id": {"type": "string", "description": "知识库ID"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
"code_interpreter": {
|
||||
"name": "代码执行",
|
||||
"description": "安全地执行Python代码",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {"type": "string", "description": "要执行的Python代码"}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
def list_available_tools():
|
||||
"""获取可用的工具列表"""
|
||||
return {"tools": TOOL_REGISTRY}
|
||||
|
||||
|
||||
@router.get("/list/{tool_id}")
|
||||
def get_tool_detail(tool_id: str):
|
||||
"""获取工具详情"""
|
||||
if tool_id not in TOOL_REGISTRY:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
return TOOL_REGISTRY[tool_id]
|
||||
|
||||
|
||||
# ============ Autotest ============
|
||||
class AutotestResult:
|
||||
"""自动测试结果"""
|
||||
|
||||
def __init__(self):
|
||||
self.id = str(uuid.uuid4())[:8]
|
||||
self.started_at = time.time()
|
||||
self.tests = []
|
||||
self.summary = {"passed": 0, "failed": 0, "total": 0}
|
||||
|
||||
def add_test(self, name: str, passed: bool, message: str = "", duration_ms: int = 0):
|
||||
self.tests.append({
|
||||
"name": name,
|
||||
"passed": passed,
|
||||
"message": message,
|
||||
"duration_ms": duration_ms
|
||||
})
|
||||
if passed:
|
||||
self.summary["passed"] += 1
|
||||
else:
|
||||
self.summary["failed"] += 1
|
||||
self.summary["total"] += 1
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"started_at": self.started_at,
|
||||
"duration_ms": int((time.time() - self.started_at) * 1000),
|
||||
"tests": self.tests,
|
||||
"summary": self.summary
|
||||
}
|
||||
|
||||
|
||||
@router.post("/autotest")
|
||||
def run_autotest(
|
||||
llm_model_id: Optional[str] = None,
|
||||
asr_model_id: Optional[str] = None,
|
||||
test_llm: bool = True,
|
||||
test_asr: bool = True,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""运行自动测试"""
|
||||
result = AutotestResult()
|
||||
|
||||
# 测试 LLM 模型
|
||||
if test_llm and llm_model_id:
|
||||
_test_llm_model(db, llm_model_id, result)
|
||||
|
||||
# 测试 ASR 模型
|
||||
if test_asr and asr_model_id:
|
||||
_test_asr_model(db, asr_model_id, result)
|
||||
|
||||
# 测试 TTS 功能(需要时可添加)
|
||||
if test_llm and not llm_model_id:
|
||||
result.add_test(
|
||||
"LLM Model Check",
|
||||
False,
|
||||
"No LLM model ID provided"
|
||||
)
|
||||
|
||||
if test_asr and not asr_model_id:
|
||||
result.add_test(
|
||||
"ASR Model Check",
|
||||
False,
|
||||
"No ASR model ID provided"
|
||||
)
|
||||
|
||||
return result.to_dict()
|
||||
|
||||
|
||||
@router.post("/autotest/llm/{model_id}")
|
||||
def autotest_llm_model(model_id: str, db: Session = Depends(get_db)):
|
||||
"""测试单个LLM模型"""
|
||||
result = AutotestResult()
|
||||
_test_llm_model(db, model_id, result)
|
||||
return result.to_dict()
|
||||
|
||||
|
||||
@router.post("/autotest/asr/{model_id}")
|
||||
def autotest_asr_model(model_id: str, db: Session = Depends(get_db)):
|
||||
"""测试单个ASR模型"""
|
||||
result = AutotestResult()
|
||||
_test_asr_model(db, model_id, result)
|
||||
return result.to_dict()
|
||||
|
||||
|
||||
def _test_llm_model(db: Session, model_id: str, result: AutotestResult):
|
||||
"""内部方法:测试LLM模型"""
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 检查模型是否存在
|
||||
model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
if not model:
|
||||
result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
|
||||
return
|
||||
|
||||
result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)
|
||||
|
||||
# 2. 测试连接
|
||||
test_start = time.time()
|
||||
try:
|
||||
test_messages = [{"role": "user", "content": "Reply with 'OK'."}]
|
||||
payload = {
|
||||
"model": model.model_name or "gpt-3.5-turbo",
|
||||
"messages": test_messages,
|
||||
"max_tokens": 10,
|
||||
"temperature": 0.1,
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/chat/completions",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result_text = response.json()
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
|
||||
if result_text.get("choices"):
|
||||
result.add_test("API Connection", True, f"Latency: {latency_ms}ms", latency_ms)
|
||||
else:
|
||||
result.add_test("API Connection", False, "Empty response", latency_ms)
|
||||
|
||||
except Exception as e:
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
result.add_test("API Connection", False, str(e)[:200], latency_ms)
|
||||
|
||||
# 3. 检查模型配置
|
||||
if model.temperature is not None:
|
||||
result.add_test("Temperature Setting", True, f"temperature={model.temperature}")
|
||||
else:
|
||||
result.add_test("Temperature Setting", True, "Using default")
|
||||
|
||||
# 4. 测试流式响应(可选)
|
||||
if model.type == "text":
|
||||
test_start = time.time()
|
||||
try:
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
with client.stream(
|
||||
"POST",
|
||||
f"{model.base_url}/chat/completions",
|
||||
json={
|
||||
"model": model.model_name or "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Count from 1 to 3."}],
|
||||
"stream": True,
|
||||
},
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
chunk_count = 0
|
||||
for _ in response.iter_bytes():
|
||||
chunk_count += 1
|
||||
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
result.add_test("Streaming Support", True, f"Received {chunk_count} chunks", latency_ms)
|
||||
except Exception as e:
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
result.add_test("Streaming Support", False, str(e)[:200], latency_ms)
|
||||
|
||||
|
||||
def _test_asr_model(db: Session, model_id: str, result: AutotestResult):
|
||||
"""内部方法:测试ASR模型"""
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 检查模型是否存在
|
||||
model = db.query(ASRModel).filter(ASRModel.id == model_id).first()
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
if not model:
|
||||
result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
|
||||
return
|
||||
|
||||
result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)
|
||||
|
||||
# 2. 测试配置
|
||||
if model.hotwords:
|
||||
result.add_test("Hotwords Config", True, f"Hotwords: {len(model.hotwords)} words")
|
||||
else:
|
||||
result.add_test("Hotwords Config", True, "No hotwords configured")
|
||||
|
||||
# 3. 测试API可用性
|
||||
test_start = time.time()
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
if model.vendor.lower() in ["siliconflow", "paraformer"]:
|
||||
response = client.get(
|
||||
f"{model.base_url}/asr",
|
||||
headers=headers
|
||||
)
|
||||
elif model.vendor.lower() == "openai":
|
||||
response = client.get(
|
||||
f"{model.base_url}/audio/models",
|
||||
headers=headers
|
||||
)
|
||||
else:
|
||||
# 通用健康检查
|
||||
response = client.get(
|
||||
f"{model.base_url}/health",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
|
||||
if response.status_code in [200, 405]: # 405 = method not allowed but endpoint exists
|
||||
result.add_test("API Availability", True, f"Status: {response.status_code}", latency_ms)
|
||||
else:
|
||||
result.add_test("API Availability", False, f"Status: {response.status_code}", latency_ms)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
result.add_test("API Availability", False, "Connection timeout", latency_ms)
|
||||
except Exception as e:
|
||||
latency_ms = int((time.time() - test_start) * 1000)
|
||||
result.add_test("API Availability", False, str(e)[:200], latency_ms)
|
||||
|
||||
# 4. 检查语言配置
|
||||
if model.language in ["zh", "en", "Multi-lingual"]:
|
||||
result.add_test("Language Config", True, f"Language: {model.language}")
|
||||
else:
|
||||
result.add_test("Language Config", False, f"Unknown language: {model.language}")
|
||||
|
||||
|
||||
# ============ Quick Health Check ============
|
||||
@router.get("/health")
|
||||
def health_check():
|
||||
"""快速健康检查"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"timestamp": time.time(),
|
||||
"tools": list(TOOL_REGISTRY.keys())
|
||||
}
|
||||
|
||||
|
||||
@router.post("/test-message")
|
||||
def send_test_message(
|
||||
llm_model_id: str,
|
||||
message: str = "Hello, this is a test message.",
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""发送测试消息"""
|
||||
model = db.query(LLMModel).filter(LLMModel.id == llm_model_id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="LLM Model not found")
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"model": model.model_name or "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": message}],
|
||||
"max_tokens": 500,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
response = client.post(
|
||||
f"{model.base_url}/chat/completions",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
reply = result.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"reply": reply,
|
||||
"usage": result.get("usage", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
Reference in New Issue
Block a user