564 lines
19 KiB
Python
564 lines
19 KiB
Python
from fastapi import APIRouter, Depends, HTTPException
|
||
from sqlalchemy.orm import Session
|
||
from typing import Optional, Dict, Any, List
|
||
import time
|
||
import uuid
|
||
import httpx
|
||
from datetime import datetime
|
||
|
||
from ..db import get_db
|
||
from ..models import LLMModel, ASRModel, ToolResource
|
||
from ..schemas import ToolResourceCreate, ToolResourceOut, ToolResourceUpdate
|
||
|
||
# Router for the built-in tool catalogue, tool-resource CRUD and model autotests;
# every route declared below is mounted under the "/tools" prefix.
router = APIRouter(prefix="/tools", tags=["Tools & Autotest"])
|
||
|
||
|
||
# ============ Available Tools ============


def _object_schema(
    properties: Optional[Dict[str, Any]] = None,
    required: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build a fresh JSON-Schema ``object`` description for a tool's parameters.

    New dict/list objects are created on every call so that no two registry
    entries share mutable state.
    """
    return {
        "type": "object",
        "properties": dict(properties or {}),
        "required": list(required or []),
    }


# Built-in tool catalogue keyed by tool id. Each entry carries a display name,
# a description and a JSON-Schema "parameters" object describing its arguments.
TOOL_REGISTRY = {
    "calculator": {
        "name": "计算器",
        "description": "执行数学计算",
        "parameters": _object_schema(
            {"expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}},
            ["expression"],
        ),
    },
    "code_interpreter": {
        "name": "代码执行",
        "description": "安全地执行Python代码",
        "parameters": _object_schema(
            {"code": {"type": "string", "description": "要执行的Python代码"}},
            ["code"],
        ),
    },
    "current_time": {
        "name": "当前时间",
        "description": "获取当前本地时间",
        "parameters": _object_schema(),
    },
    "turn_on_camera": {
        "name": "打开摄像头",
        "description": "执行打开摄像头命令",
        "parameters": _object_schema(),
    },
    "turn_off_camera": {
        "name": "关闭摄像头",
        "description": "执行关闭摄像头命令",
        "parameters": _object_schema(),
    },
    "increase_volume": {
        "name": "调高音量",
        "description": "提升设备音量",
        "parameters": _object_schema(
            {"step": {"type": "integer", "description": "调整步进,默认1"}},
        ),
    },
    "decrease_volume": {
        "name": "调低音量",
        "description": "降低设备音量",
        "parameters": _object_schema(
            {"step": {"type": "integer", "description": "调整步进,默认1"}},
        ),
    },
}
|
||
|
||
# Execution type of each built-in tool ("query" or "system"), copied into
# ToolResource.category when seeding (unknown ids fall back to "system").
# The type describes how a tool runs; it is not a permission level.
TOOL_CATEGORY_MAP = {
    **dict.fromkeys(("calculator", "current_time", "code_interpreter"), "query"),
    **dict.fromkeys(
        ("turn_on_camera", "turn_off_camera", "increase_volume", "decrease_volume"),
        "system",
    ),
}

# Icon identifier of each built-in tool, copied into ToolResource.icon when
# seeding (unknown ids fall back to "Wrench").
TOOL_ICON_MAP = dict(
    calculator="Terminal",
    current_time="Calendar",
    code_interpreter="Terminal",
    turn_on_camera="Camera",
    turn_off_camera="CameraOff",
    increase_volume="Volume2",
    decrease_volume="Volume2",
)

# Optional per-tool HTTP defaults applied when seeding. The seeder reads the keys
# "http_method", "http_url", "http_headers" and "http_timeout_ms"; no built-in
# tool currently defines any.
TOOL_HTTP_DEFAULTS: Dict[str, Dict[str, Any]] = {}
|
||
|
||
|
||
def _normalize_http_method(method: Optional[str]) -> str:
|
||
normalized = str(method or "GET").strip().upper()
|
||
return normalized if normalized in {"GET", "POST", "PUT", "PATCH", "DELETE"} else "GET"
|
||
|
||
|
||
def _requires_http_request(category: str, tool_id: Optional[str]) -> bool:
|
||
if category != "query":
|
||
return False
|
||
return str(tool_id or "").strip() not in {"calculator", "code_interpreter", "current_time"}
|
||
|
||
|
||
def _validate_query_http_config(*, category: str, tool_id: Optional[str], http_url: Optional[str]) -> None:
|
||
if _requires_http_request(category, tool_id) and not str(http_url or "").strip():
|
||
raise HTTPException(status_code=400, detail="http_url is required for query tools (except calculator/code_interpreter)")
|
||
|
||
def _seed_default_tools_if_empty(db: Session) -> None:
    """Insert every TOOL_REGISTRY entry as a system ToolResource, only into an empty table.

    If tool_resources already holds at least one row, nothing is written.
    Seeded rows:
      - are owned by user 1, enabled, and flagged ``is_system``;
      - take category and icon from TOOL_CATEGORY_MAP / TOOL_ICON_MAP
        (falling back to "system" / "Wrench");
      - take HTTP settings from TOOL_HTTP_DEFAULTS, with the method normalized,
        headers defaulting to ``{}`` and the timeout defaulting to 10000 ms.
    """
    already_seeded = db.query(ToolResource).count() > 0
    if already_seeded:
        return

    def _build_row(tool_id: str, spec: Dict[str, Any]) -> ToolResource:
        # Build one system ToolResource from a registry entry plus its HTTP defaults.
        defaults = TOOL_HTTP_DEFAULTS.get(tool_id, {})
        return ToolResource(
            id=tool_id,
            user_id=1,
            name=spec.get("name", tool_id),
            description=spec.get("description", ""),
            category=TOOL_CATEGORY_MAP.get(tool_id, "system"),
            icon=TOOL_ICON_MAP.get(tool_id, "Wrench"),
            http_method=_normalize_http_method(defaults.get("http_method")),
            http_url=defaults.get("http_url"),
            http_headers=defaults.get("http_headers") or {},
            http_timeout_ms=int(defaults.get("http_timeout_ms") or 10000),
            enabled=True,
            is_system=True,
        )

    db.add_all(_build_row(tool_id, spec) for tool_id, spec in TOOL_REGISTRY.items())
    db.commit()
|
||
|
||
|
||
def recreate_tool_resources(db: Session) -> None:
    """Drop and re-create the tool_resources table, then reseed the built-in tools.

    Destructive: every existing row, including user-created tools, is lost.
    Both the drop and the create use ``checkfirst=True``, so a missing table is
    not an error.
    """
    table = ToolResource.__table__
    engine = db.get_bind()
    table.drop(bind=engine, checkfirst=True)
    table.create(bind=engine, checkfirst=True)
    _seed_default_tools_if_empty(db)
|
||
|
||
|
||
@router.get("/list")
def list_available_tools():
    """Return the built-in tool catalogue wrapped as ``{"tools": TOOL_REGISTRY}``."""
    return dict(tools=TOOL_REGISTRY)
|
||
|
||
|
||
@router.get("/list/{tool_id}")
def get_tool_detail(tool_id: str):
    """Return the TOOL_REGISTRY entry for ``tool_id``.

    Raises:
        HTTPException: 404 when ``tool_id`` is not a built-in tool.
    """
    try:
        return TOOL_REGISTRY[tool_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Tool not found") from None
|
||
|
||
|
||
# ============ Tool Resource CRUD ============

@router.get("/resources")
def list_tool_resources(
    category: Optional[str] = None,
    enabled: Optional[bool] = None,
    include_system: bool = True,
    page: int = 1,
    limit: int = 100,
    db: Session = Depends(get_db),
):
    """List tool resources with optional filters and page-based pagination.

    system/query only describe how a tool is executed; they do not imply permissions.

    Args:
        category: When given, only tools with exactly this category are returned.
        enabled: When given, only tools whose ``enabled`` flag matches are returned.
        include_system: When False, built-in (``is_system``) tools are excluded.
        page: 1-based page number; values below 1 are treated as 1.
        limit: Page size; values below 1 are treated as 1.
        db: Injected SQLAlchemy session.

    Returns:
        ``{"total", "page", "limit", "list"}``: ``total`` counts all matching rows,
        ``list`` holds the ToolResource rows of the requested page (newest first),
        and ``page``/``limit`` echo the effective (clamped) values.
    """
    _seed_default_tools_if_empty(db)

    # Clamp pagination. Previously a non-positive limit produced a negative
    # LIMIT/OFFSET, which databases such as PostgreSQL reject with an error.
    page = max(page, 1)
    limit = max(limit, 1)

    query = db.query(ToolResource)
    if not include_system:
        query = query.filter(ToolResource.is_system == False)  # noqa: E712 - SQL expression
    if category:
        query = query.filter(ToolResource.category == category)
    if enabled is not None:
        query = query.filter(ToolResource.enabled == enabled)

    total = query.count()
    rows = (
        query.order_by(ToolResource.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": rows}
|
||
|
||
|
||
@router.get("/resources/{id}", response_model=ToolResourceOut)
def get_tool_resource(id: str, db: Session = Depends(get_db)):
    """Fetch a single tool resource by id, seeding built-ins first if the table is empty.

    Raises:
        HTTPException: 404 when no row has this id.
    """
    _seed_default_tools_if_empty(db)
    found = db.query(ToolResource).filter_by(id=id).first()
    if not found:
        raise HTTPException(status_code=404, detail="Tool resource not found")
    return found
|
||
|
||
|
||
@router.post("/resources", response_model=ToolResourceOut)
def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)):
    """Create a user-defined (non-system) tool resource owned by user 1.

    A caller-supplied id is stripped and must be unique; without one, an id of the
    form ``tool_<first 8 chars of a UUID4>`` is generated. HTTP settings are
    normalized:
      - the method falls back to GET;
      - a blank URL is stored as NULL;
      - headers default to ``{}``;
      - the timeout defaults to 10000 ms with a 1000 ms floor.

    Raises:
        HTTPException: 400 if the id already exists or a query tool lacks ``http_url``.
    """
    _seed_default_tools_if_empty(db)

    requested_id = (data.id or "").strip()
    if requested_id:
        clash = db.query(ToolResource).filter(ToolResource.id == requested_id).first()
        if clash:
            raise HTTPException(status_code=400, detail="Tool ID already exists")

    _validate_query_http_config(category=data.category, tool_id=requested_id, http_url=data.http_url)

    new_id = requested_id or "tool_" + str(uuid.uuid4())[:8]
    item = ToolResource(
        id=new_id,
        user_id=1,
        name=data.name,
        description=data.description,
        category=data.category,
        icon=data.icon,
        http_method=_normalize_http_method(data.http_method),
        http_url=(data.http_url or "").strip() or None,
        http_headers=data.http_headers or {},
        http_timeout_ms=max(1000, int(data.http_timeout_ms or 10000)),
        enabled=data.enabled,
        is_system=False,
    )

    db.add(item)
    db.commit()
    db.refresh(item)
    return item
|
||
|
||
|
||
@router.put("/resources/{id}", response_model=ToolResourceOut)
def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)):
    """Partially update a tool resource with only the fields the client sent.

    HTTP settings are normalized the same way ``create_tool_resource`` does:
      - the method falls back to GET;
      - a blank ``http_url`` is stored as NULL;
      - an explicit ``http_headers`` of None becomes ``{}``;
      - a non-null timeout gets a 1000 ms floor.

    Raises:
        HTTPException: 404 if the tool does not exist; 400 if the resulting
            configuration is a query tool without ``http_url``.
    """
    _seed_default_tools_if_empty(db)
    item = db.query(ToolResource).filter(ToolResource.id == id).first()
    if not item:
        raise HTTPException(status_code=404, detail="Tool resource not found")

    # Only fields explicitly present in the request body are applied.
    update_data = data.model_dump(exclude_unset=True)

    # Validate the merged configuration (incoming values over existing ones).
    new_category = update_data.get("category", item.category)
    new_http_url = update_data.get("http_url", item.http_url)
    _validate_query_http_config(category=new_category, tool_id=id, http_url=new_http_url)

    if "http_method" in update_data:
        update_data["http_method"] = _normalize_http_method(update_data["http_method"])
    if "http_url" in update_data:
        # Match create_tool_resource: strip whitespace and store blanks as NULL.
        update_data["http_url"] = (update_data["http_url"] or "").strip() or None
    if "http_headers" in update_data and update_data["http_headers"] is None:
        # Match create_tool_resource, which never stores NULL headers.
        update_data["http_headers"] = {}
    if update_data.get("http_timeout_ms") is not None:
        update_data["http_timeout_ms"] = max(1000, int(update_data["http_timeout_ms"]))

    for field, value in update_data.items():
        setattr(item, field, value)
    item.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(item)
    return item
|
||
|
||
|
||
@router.delete("/resources/{id}")
def delete_tool_resource(id: str, db: Session = Depends(get_db)):
    """Delete a tool resource by id and return a confirmation message.

    Raises:
        HTTPException: 404 when no row has this id.
    """
    _seed_default_tools_if_empty(db)
    target = db.query(ToolResource).filter_by(id=id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Tool resource not found")
    # NOTE(review): system tools can be deleted too. Once every row is gone, the
    # next call to _seed_default_tools_if_empty re-creates all built-ins.
    db.delete(target)
    db.commit()
    return {"message": "Deleted successfully"}
|
||
|
||
|
||
# ============ Autotest ============

class AutotestResult:
    """Accumulate named pass/fail checks for one autotest run and serialize them."""

    def __init__(self):
        # Short run id: the first 8 characters of a random UUID4 string.
        self.id = str(uuid.uuid4())[:8]
        # Wall-clock start in epoch seconds; to_dict() measures elapsed time from here.
        self.started_at = time.time()
        self.tests = []
        self.summary = {"passed": 0, "failed": 0, "total": 0}

    def add_test(self, name: str, passed: bool, message: str = "", duration_ms: int = 0):
        """Record one check and bump the matching pass/fail counter plus the total."""
        self.tests.append(dict(name=name, passed=passed, message=message, duration_ms=duration_ms))
        bucket = "passed" if passed else "failed"
        self.summary[bucket] += 1
        self.summary["total"] += 1

    def to_dict(self):
        """Return the run as a plain dict; ``duration_ms`` is measured at call time.

        ``tests`` and ``summary`` are the live internal objects, not copies.
        """
        elapsed_ms = int((time.time() - self.started_at) * 1000)
        return dict(
            id=self.id,
            started_at=self.started_at,
            duration_ms=elapsed_ms,
            tests=self.tests,
            summary=self.summary,
        )
|
||
|
||
|
||
@router.post("/autotest")
def run_autotest(
    llm_model_id: Optional[str] = None,
    asr_model_id: Optional[str] = None,
    test_llm: bool = True,
    test_asr: bool = True,
    db: Session = Depends(get_db)
):
    """Run the requested LLM/ASR model checks and return the aggregated result.

    A requested check whose model id is missing is recorded as a failed
    "LLM Model Check" / "ASR Model Check" entry, after any model tests that ran.
    """
    report = AutotestResult()

    # Real model tests first: LLM, then ASR.
    if test_llm and llm_model_id:
        _test_llm_model(db, llm_model_id, report)
    if test_asr and asr_model_id:
        _test_asr_model(db, asr_model_id, report)

    # TTS checks can be added here when needed.

    # Then record failures for requested checks that were given no model id.
    missing_checks = (
        (test_llm and not llm_model_id, "LLM Model Check", "No LLM model ID provided"),
        (test_asr and not asr_model_id, "ASR Model Check", "No ASR model ID provided"),
    )
    for is_missing, check_name, reason in missing_checks:
        if is_missing:
            report.add_test(check_name, False, reason)

    return report.to_dict()
|
||
|
||
|
||
@router.post("/autotest/llm/{model_id}")
def autotest_llm_model(model_id: str, db: Session = Depends(get_db)):
    """Run only the LLM checks for ``model_id`` and return the serialized result."""
    outcome = AutotestResult()
    _test_llm_model(db, model_id, outcome)
    return outcome.to_dict()
|
||
|
||
|
||
@router.post("/autotest/asr/{model_id}")
def autotest_asr_model(model_id: str, db: Session = Depends(get_db)):
    """Run only the ASR checks for ``model_id`` and return the serialized result."""
    outcome = AutotestResult()
    _test_asr_model(db, model_id, outcome)
    return outcome.to_dict()
|
||
|
||
|
||
def _test_llm_model(db: Session, model_id: str, result: AutotestResult):
    """Run the LLM checks for one model and record them into ``result``.

    Checks, in order:
      1. "Model Existence": the LLMModel row exists (stops here if not).
      2. "API Connection": a non-streaming POST to ``{base_url}/chat/completions``
         succeeds and the JSON body has a non-empty ``choices``.
      3. "Temperature Setting": informational; always recorded as passed.
      4. "Streaming Support": only when ``model.type == "text"``. A streaming
         request succeeds; the message reports the number of raw byte chunks read.

    Request failures are caught and recorded as failed checks, with the error
    text truncated to 200 characters; nothing is raised to the caller.
    """
    start_time = time.time()

    # 1. Model existence
    model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
    duration_ms = int((time.time() - start_time) * 1000)
    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return
    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # Shared request setup. Trailing slashes are stripped so a base_url such as
    # "https://host/v1/" does not produce ".../v1//chat/completions".
    base_url = (model.base_url or "").rstrip("/")
    completions_url = f"{base_url}/chat/completions"
    model_name = model.model_name or "gpt-3.5-turbo"
    headers = {"Authorization": f"Bearer {model.api_key}"}

    # 2. Connectivity: one small, non-streaming completion
    test_start = time.time()
    try:
        payload = {
            "model": model_name,
            "messages": [{"role": "user", "content": "Reply with 'OK'."}],
            "max_tokens": 10,
            "temperature": 0.1,
        }
        with httpx.Client(timeout=30.0) as client:
            response = client.post(completions_url, json=payload, headers=headers)
            response.raise_for_status()
            body = response.json()
        latency_ms = int((time.time() - test_start) * 1000)
        if body.get("choices"):
            result.add_test("API Connection", True, f"Latency: {latency_ms}ms", latency_ms)
        else:
            result.add_test("API Connection", False, "Empty response", latency_ms)
    except Exception as e:
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Connection", False, str(e)[:200], latency_ms)

    # 3. Configuration (informational only)
    if model.temperature is not None:
        result.add_test("Temperature Setting", True, f"temperature={model.temperature}")
    else:
        result.add_test("Temperature Setting", True, "Using default")

    # 4. Streaming support (text models only)
    if model.type == "text":
        test_start = time.time()
        try:
            stream_payload = {
                "model": model_name,
                "messages": [{"role": "user", "content": "Count from 1 to 3."}],
                "stream": True,
            }
            with httpx.Client(timeout=30.0) as client:
                with client.stream("POST", completions_url, json=stream_payload, headers=headers) as response:
                    response.raise_for_status()
                    # Counts raw byte chunks as delivered by the transport, not SSE events.
                    chunk_count = sum(1 for _ in response.iter_bytes())
            latency_ms = int((time.time() - test_start) * 1000)
            result.add_test("Streaming Support", True, f"Received {chunk_count} chunks", latency_ms)
        except Exception as e:
            latency_ms = int((time.time() - test_start) * 1000)
            result.add_test("Streaming Support", False, str(e)[:200], latency_ms)
|
||
|
||
|
||
def _test_asr_model(db: Session, model_id: str, result: AutotestResult):
    """Run the ASR checks for one model and record them into ``result``.

    Checks, in order:
      1. "Model Existence": the ASRModel row exists (stops here if not).
      2. "Hotwords Config": informational; always recorded as passed.
      3. "API Availability": a GET to a vendor-specific probe path under
         ``base_url``. A 200 or 405 status counts as available.
      4. "Language Config": ``model.language`` is one of "zh", "en", "Multi-lingual".

    Request failures are caught and recorded as failed checks; nothing is raised.
    """
    start_time = time.time()

    # 1. Model existence
    model = db.query(ASRModel).filter(ASRModel.id == model_id).first()
    duration_ms = int((time.time() - start_time) * 1000)
    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return
    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # 2. Hotwords (informational)
    if model.hotwords:
        # NOTE(review): len() counts items if hotwords is a list, but characters
        # if it is stored as a single string. Confirm the column type.
        result.add_test("Hotwords Config", True, f"Hotwords: {len(model.hotwords)} words")
    else:
        result.add_test("Hotwords Config", True, "No hotwords configured")

    # 3. API availability via a vendor-specific probe endpoint.
    # Trailing slashes are stripped so "https://host/" does not produce "https://host//asr".
    base_url = (model.base_url or "").rstrip("/")
    test_start = time.time()
    try:
        headers = {"Authorization": f"Bearer {model.api_key}"}
        # Use the normalized vendor for every comparison. Previously the "openai"
        # branch called model.vendor.lower() directly, which raised on a None
        # vendor and ignored surrounding whitespace.
        normalized_vendor = (model.vendor or "").strip().lower()
        if normalized_vendor in (
            "openai compatible",
            "openai-compatible",
            "siliconflow",  # backward compatibility
            "paraformer",
        ):
            probe_path = "/asr"
        elif normalized_vendor == "openai":
            probe_path = "/audio/models"
        else:
            probe_path = "/health"  # generic health check

        with httpx.Client(timeout=30.0) as client:
            response = client.get(f"{base_url}{probe_path}", headers=headers)

        latency_ms = int((time.time() - test_start) * 1000)
        # 405 = method not allowed, but the endpoint exists.
        available = response.status_code in (200, 405)
        result.add_test("API Availability", available, f"Status: {response.status_code}", latency_ms)
    except httpx.TimeoutException:
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Availability", False, "Connection timeout", latency_ms)
    except Exception as e:
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Availability", False, str(e)[:200], latency_ms)

    # 4. Language configuration
    if model.language in ["zh", "en", "Multi-lingual"]:
        result.add_test("Language Config", True, f"Language: {model.language}")
    else:
        result.add_test("Language Config", False, f"Unknown language: {model.language}")
|
||
|
||
|
||
# ============ Quick Health Check ============

@router.get("/health")
def health_check():
    """Lightweight liveness probe.

    Returns a fixed "healthy" status, the current epoch timestamp and the ids
    of the built-in tools.
    """
    now = time.time()
    tool_ids = list(TOOL_REGISTRY)
    return {"status": "healthy", "timestamp": now, "tools": tool_ids}
|
||
|
||
|
||
@router.post("/test-message")
def send_test_message(
    llm_model_id: str,
    message: str = "Hello, this is a test message.",
    db: Session = Depends(get_db)
):
    """Send one chat message to an LLM model and return its reply.

    Returns:
        ``{"success": True, "reply": <assistant text, or "" when absent>,
        "usage": <response "usage" object, or {}>}``.

    Raises:
        HTTPException: 404 if the model does not exist; 500 carrying the error
            text if the request fails or the response cannot be parsed.
    """
    model = db.query(LLMModel).filter(LLMModel.id == llm_model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="LLM Model not found")

    payload = {
        "model": model.model_name or "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": message}],
        "max_tokens": 500,
        "temperature": 0.7,
    }
    headers = {"Authorization": f"Bearer {model.api_key}"}
    # Strip trailing slashes so "https://host/v1/" does not produce ".../v1//chat/completions".
    base_url = (model.base_url or "").rstrip("/")

    try:
        with httpx.Client(timeout=60.0) as client:
            response = client.post(f"{base_url}/chat/completions", json=payload, headers=headers)
            response.raise_for_status()
        result = response.json()
        # Tolerate an empty/null "choices" list or a null "message" object instead of
        # failing with IndexError/AttributeError; a missing reply becomes "".
        choices = result.get("choices") or [{}]
        reply = (choices[0].get("message") or {}).get("content", "")
        usage = result.get("usage", {})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"success": True, "reply": reply, "usage": usage}
|