from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Optional, Dict, Any
import time
import uuid
import httpx

from ..db import get_db
from ..models import LLMModel, ASRModel
# Router for tool discovery and self-test endpoints.
router = APIRouter(tags=["Tools & Autotest"], prefix="/tools")
# ============ Available Tools ============

def _tool_spec(name: str, description: str, properties: dict, required: list) -> dict:
    """Build one registry entry in the JSON-Schema function-calling shape."""
    return {
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": required,
        },
    }


# Registry of tools exposed to the frontend / agent layer.
TOOL_REGISTRY = {
    "search": _tool_spec(
        "网络搜索",
        "搜索互联网获取最新信息",
        {"query": {"type": "string", "description": "搜索关键词"}},
        ["query"],
    ),
    "calculator": _tool_spec(
        "计算器",
        "执行数学计算",
        {"expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}},
        ["expression"],
    ),
    "weather": _tool_spec(
        "天气查询",
        "查询指定城市的天气",
        {"city": {"type": "string", "description": "城市名称"}},
        ["city"],
    ),
    "translate": _tool_spec(
        "翻译",
        "翻译文本到指定语言",
        {
            "text": {"type": "string", "description": "要翻译的文本"},
            "target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"},
        },
        ["text", "target_lang"],
    ),
    "knowledge": _tool_spec(
        "知识库查询",
        "从知识库中检索相关信息",
        {
            "query": {"type": "string", "description": "查询内容"},
            "kb_id": {"type": "string", "description": "知识库ID"},
        },
        ["query"],
    ),
    "code_interpreter": _tool_spec(
        "代码执行",
        "安全地执行Python代码",
        {"code": {"type": "string", "description": "要执行的Python代码"}},
        ["code"],
    ),
}
@router.get("/list")
def list_available_tools():
    """Return the full registry of available tools."""
    return dict(tools=TOOL_REGISTRY)
@router.get("/list/{tool_id}")
def get_tool_detail(tool_id: str):
    """Return the registry entry for one tool; 404 when the id is unknown."""
    try:
        return TOOL_REGISTRY[tool_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Tool not found")
# ============ Autotest ============
class AutotestResult:
    """Accumulates the outcome of a batch of self-tests."""

    def __init__(self):
        # Short random identifier for this run.
        self.id = str(uuid.uuid4())[:8]
        self.started_at = time.time()
        self.tests = []
        self.summary = {"passed": 0, "failed": 0, "total": 0}

    def add_test(self, name: str, passed: bool, message: str = "", duration_ms: int = 0):
        """Record one test result and update the running summary counters."""
        entry = {
            "name": name,
            "passed": passed,
            "message": message,
            "duration_ms": duration_ms,
        }
        self.tests.append(entry)
        self.summary["passed" if passed else "failed"] += 1
        self.summary["total"] += 1

    def to_dict(self):
        """Serialize the run, including elapsed wall-clock time in ms."""
        elapsed_ms = int((time.time() - self.started_at) * 1000)
        return {
            "id": self.id,
            "started_at": self.started_at,
            "duration_ms": elapsed_ms,
            "tests": self.tests,
            "summary": self.summary,
        }
@router.post("/autotest")
def run_autotest(
    llm_model_id: Optional[str] = None,
    asr_model_id: Optional[str] = None,
    test_llm: bool = True,
    test_asr: bool = True,
    db: Session = Depends(get_db)
):
    """Run the automated self-test suite.

    For each enabled category with a model id, the matching tester is run;
    enabled categories missing an id are recorded afterwards as failed
    checks, preserving the report ordering.
    """
    result = AutotestResult()

    # Run the per-model test suites first.
    for enabled, model_id, runner in (
        (test_llm, llm_model_id, _test_llm_model),
        (test_asr, asr_model_id, _test_asr_model),
    ):
        if enabled and model_id:
            runner(db, model_id, result)

    # Then record failed checks for enabled categories without a model id.
    for enabled, model_id, label in (
        (test_llm, llm_model_id, "LLM"),
        (test_asr, asr_model_id, "ASR"),
    ):
        if enabled and not model_id:
            result.add_test(
                f"{label} Model Check",
                False,
                f"No {label} model ID provided"
            )

    return result.to_dict()
@router.post("/autotest/llm/{model_id}")
def autotest_llm_model(model_id: str, db: Session = Depends(get_db)):
    """Run the self-test suite against a single LLM model."""
    report = AutotestResult()
    _test_llm_model(db, model_id, report)
    return report.to_dict()
@router.post("/autotest/asr/{model_id}")
def autotest_asr_model(model_id: str, db: Session = Depends(get_db)):
    """Run the self-test suite against a single ASR model."""
    report = AutotestResult()
    _test_asr_model(db, model_id, report)
    return report.to_dict()
def _test_llm_model(db: Session, model_id: str, result: AutotestResult):
    """Internal: test one LLM model and record outcomes in *result*.

    Checks, in order: model row existence, API connectivity via a tiny
    non-streaming completion, temperature configuration, and (for text
    models) streaming support.
    """
    start_time = time.time()

    # 1. The model row must exist.
    model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
    duration_ms = int((time.time() - start_time) * 1000)

    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return

    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # Shared auth header. Hoisted out of the connection check's try-body so
    # the streaming check below no longer depends on that block having run.
    headers = {"Authorization": f"Bearer {model.api_key}"}

    # 2. Connectivity: one tiny, cheap completion request.
    test_start = time.time()
    try:
        payload = {
            "model": model.model_name or "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "Reply with 'OK'."}],
            "max_tokens": 10,
            "temperature": 0.1,
        }

        with httpx.Client(timeout=30.0) as client:
            response = client.post(
                f"{model.base_url}/chat/completions",
                json=payload,
                headers=headers
            )
            response.raise_for_status()

        result_text = response.json()
        latency_ms = int((time.time() - test_start) * 1000)

        if result_text.get("choices"):
            result.add_test("API Connection", True, f"Latency: {latency_ms}ms", latency_ms)
        else:
            result.add_test("API Connection", False, "Empty response", latency_ms)

    except httpx.TimeoutException:
        # Report timeouts the same way _test_asr_model does.
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Connection", False, "Connection timeout", latency_ms)
    except Exception as e:
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Connection", False, str(e)[:200], latency_ms)

    # 3. Temperature is optional, so both outcomes count as passing.
    if model.temperature is not None:
        result.add_test("Temperature Setting", True, f"temperature={model.temperature}")
    else:
        result.add_test("Temperature Setting", True, "Using default")

    # 4. Streaming support (text models only).
    if model.type == "text":
        test_start = time.time()
        try:
            with httpx.Client(timeout=30.0) as client:
                with client.stream(
                    "POST",
                    f"{model.base_url}/chat/completions",
                    json={
                        "model": model.model_name or "gpt-3.5-turbo",
                        "messages": [{"role": "user", "content": "Count from 1 to 3."}],
                        "stream": True,
                    },
                    headers=headers
                ) as response:
                    response.raise_for_status()
                    chunk_count = 0
                    for _ in response.iter_bytes():
                        chunk_count += 1

            latency_ms = int((time.time() - test_start) * 1000)
            result.add_test("Streaming Support", True, f"Received {chunk_count} chunks", latency_ms)
        except Exception as e:
            latency_ms = int((time.time() - test_start) * 1000)
            result.add_test("Streaming Support", False, str(e)[:200], latency_ms)
def _test_asr_model(db: Session, model_id: str, result: AutotestResult):
    """Internal: test one ASR model and record outcomes in *result*.

    Checks, in order: model row existence, hotword configuration, API
    availability via a vendor-specific probe, and language configuration.
    """
    start_time = time.time()

    # 1. The model row must exist.
    model = db.query(ASRModel).filter(ASRModel.id == model_id).first()
    duration_ms = int((time.time() - start_time) * 1000)

    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return

    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # 2. Hotwords are optional, so both outcomes count as passing.
    if model.hotwords:
        result.add_test("Hotwords Config", True, f"Hotwords: {len(model.hotwords)} words")
    else:
        result.add_test("Hotwords Config", True, "No hotwords configured")

    # 3. API availability: probe a vendor-specific endpoint.
    test_start = time.time()
    try:
        headers = {"Authorization": f"Bearer {model.api_key}"}
        # Guard a NULL vendor: fall through to the generic health probe
        # instead of raising AttributeError on .lower().
        vendor = (model.vendor or "").lower()

        with httpx.Client(timeout=30.0) as client:
            if vendor in ["siliconflow", "paraformer"]:
                response = client.get(
                    f"{model.base_url}/asr",
                    headers=headers
                )
            elif vendor == "openai":
                response = client.get(
                    f"{model.base_url}/audio/models",
                    headers=headers
                )
            else:
                # Generic health check for unrecognized vendors.
                response = client.get(
                    f"{model.base_url}/health",
                    headers=headers
                )

        latency_ms = int((time.time() - test_start) * 1000)

        # 405 = method not allowed, which still proves the endpoint exists.
        if response.status_code in [200, 405]:
            result.add_test("API Availability", True, f"Status: {response.status_code}", latency_ms)
        else:
            result.add_test("API Availability", False, f"Status: {response.status_code}", latency_ms)

    except httpx.TimeoutException:
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Availability", False, "Connection timeout", latency_ms)
    except Exception as e:
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Availability", False, str(e)[:200], latency_ms)

    # 4. Language must be one of the supported codes.
    if model.language in ["zh", "en", "Multi-lingual"]:
        result.add_test("Language Config", True, f"Language: {model.language}")
    else:
        result.add_test("Language Config", False, f"Unknown language: {model.language}")
# ============ Quick Health Check ============
@router.get("/health")
def health_check():
    """Lightweight liveness probe; also advertises the registered tool ids."""
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "tools": [*TOOL_REGISTRY],
    }
@router.post("/test-message")
def send_test_message(
    llm_model_id: str,
    message: str = "Hello, this is a test message.",
    db: Session = Depends(get_db)
):
    """Send one test chat message through the given LLM model.

    Returns the model reply and token usage. Raises HTTP 404 when the
    model id is unknown and HTTP 500 when the upstream call fails.
    """
    model = db.query(LLMModel).filter(LLMModel.id == llm_model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="LLM Model not found")

    try:
        payload = {
            "model": model.model_name or "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": message}],
            "max_tokens": 500,
            "temperature": 0.7,
        }
        headers = {"Authorization": f"Bearer {model.api_key}"}

        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                f"{model.base_url}/chat/completions",
                json=payload,
                headers=headers
            )
            response.raise_for_status()

        result = response.json()
        # A present-but-empty "choices" list used to raise IndexError and
        # surface as an opaque 500; treat it like an absent key instead.
        choices = result.get("choices") or [{}]
        reply = choices[0].get("message", {}).get("content", "")

        return {
            "success": True,
            "reply": reply,
            "usage": result.get("usage", {})
        }

    except Exception as e:
        # Chain the cause so server logs keep the upstream traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e