Files
AI-VideoAssistant/api/app/routers/tools.py
2026-02-09 00:14:11 +08:00

528 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Optional, Dict, Any, List
import time
import uuid
import httpx
from datetime import datetime
from ..db import get_db
from ..models import LLMModel, ASRModel, ToolResource
from ..schemas import ToolResourceCreate, ToolResourceOut, ToolResourceUpdate
# Every endpoint in this module is mounted under the "/tools" prefix and
# grouped under the "Tools & Autotest" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/tools", tags=["Tools & Autotest"])
# ============ Available Tools ============
# Built-in ("system") tools keyed by tool ID. Each entry carries a display
# "name" and "description" (Chinese display strings returned to clients
# verbatim by the /list endpoints) and a JSON-Schema style "parameters"
# object that lists the tool's arguments and which of them are required.
# The resource CRUD endpoints treat these IDs as reserved: a custom tool
# cannot be created with, updated under, or deleted by one of these IDs.
TOOL_REGISTRY = {
    # Web search; requires "query".
    "search": {
        "name": "网络搜索",
        "description": "搜索互联网获取最新信息",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "搜索关键词"}
            },
            "required": ["query"]
        }
    },
    # Calculator; requires "expression".
    "calculator": {
        "name": "计算器",
        "description": "执行数学计算",
        "parameters": {
            "type": "object",
            "properties": {
                "expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}
            },
            "required": ["expression"]
        }
    },
    # Weather lookup; requires "city".
    "weather": {
        "name": "天气查询",
        "description": "查询指定城市的天气",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "城市名称"}
            },
            "required": ["city"]
        }
    },
    # Translation; requires both "text" and "target_lang".
    "translate": {
        "name": "翻译",
        "description": "翻译文本到指定语言",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {"type": "string", "description": "要翻译的文本"},
                "target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"}
            },
            "required": ["text", "target_lang"]
        }
    },
    # Knowledge-base retrieval; "query" is required, "kb_id" is optional.
    "knowledge": {
        "name": "知识库查询",
        "description": "从知识库中检索相关信息",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "查询内容"},
                "kb_id": {"type": "string", "description": "知识库ID"}
            },
            "required": ["query"]
        }
    },
    # Python code execution; requires "code". Only the tool's schema is
    # declared here.
    "code_interpreter": {
        "name": "代码执行",
        "description": "安全地执行Python代码",
        "parameters": {
            "type": "object",
            "properties": {
                "code": {"type": "string", "description": "要执行的Python代码"}
            },
            "required": ["code"]
        }
    },
}
# Built-in tool ID -> resource category reported to clients.
# IDs missing from this map fall back to "system" in _builtin_tool_to_resource.
TOOL_CATEGORY_MAP = {
    **dict.fromkeys(("search", "weather", "translate", "knowledge"), "query"),
    **dict.fromkeys(("calculator", "code_interpreter"), "system"),
}
# Built-in tool ID -> icon name reported to clients.
# IDs missing from this map fall back to "Wrench" in _builtin_tool_to_resource.
TOOL_ICON_MAP = dict(
    search="Globe",
    weather="CloudSun",
    translate="Globe",
    knowledge="Box",
    calculator="Terminal",
    code_interpreter="Terminal",
)
def _builtin_tool_to_resource(tool_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one TOOL_REGISTRY entry into the resource dict used by /resources.

    Built-in tools are always reported as enabled system tools. Category and
    icon are looked up by ID in TOOL_CATEGORY_MAP / TOOL_ICON_MAP, defaulting
    to "system" and "Wrench". The name falls back to the ID and the
    description falls back to an empty string.
    """
    category = TOOL_CATEGORY_MAP.get(tool_id, "system")
    icon = TOOL_ICON_MAP.get(tool_id, "Wrench")
    return dict(
        id=tool_id,
        name=payload.get("name", tool_id),
        description=payload.get("description", ""),
        category=category,
        icon=icon,
        enabled=True,
        is_system=True,
    )
@router.get("/list")
def list_available_tools():
    """Return every built-in tool definition as ``{"tools": {tool_id: definition}}``."""
    return dict(tools=TOOL_REGISTRY)
@router.get("/list/{tool_id}")
def get_tool_detail(tool_id: str):
    """Return the TOOL_REGISTRY definition for *tool_id*.

    Raises:
        HTTPException(404): *tool_id* is not a built-in tool.
    """
    tool = TOOL_REGISTRY.get(tool_id)
    if tool is None:
        raise HTTPException(status_code=404, detail="Tool not found")
    return tool
# ============ Tool Resource CRUD ============
def _custom_tool_to_resource(item: ToolResource) -> Dict[str, Any]:
    """Serialize a custom ToolResource row into the same dict shape that
    _builtin_tool_to_resource produces, so both kinds can be listed together."""
    return {
        "id": item.id,
        "name": item.name,
        "description": item.description,
        "category": item.category,
        "icon": item.icon,
        "enabled": item.enabled,
        "is_system": item.is_system,
    }


@router.get("/resources")
def list_tool_resources(
    category: Optional[str] = None,
    enabled: Optional[bool] = None,
    include_system: bool = True,
    page: int = 1,
    limit: int = 100,
    db: Session = Depends(get_db),
):
    """List tool resources: built-in tools (optional) followed by custom tools.

    Args:
        category: keep only resources whose category equals this value.
        enabled: keep only resources whose ``enabled`` flag equals this value.
        include_system: include the TOOL_REGISTRY built-ins (listed first).
        page: 1-based page number; values below 1 are treated as 1.
        limit: page size; negative values are treated as 0 (empty page).
        db: database session.

    Returns:
        ``{"total", "page", "limit", "list"}``. ``total`` counts every
        matching resource before pagination. ``page`` and ``limit`` echo the
        normalized values actually used. Custom tools are ordered newest first.
    """
    merged: List[Dict[str, Any]] = []
    if include_system:
        merged.extend(
            _builtin_tool_to_resource(tool_id, payload)
            for tool_id, payload in TOOL_REGISTRY.items()
        )

    # Custom tools are pre-filtered in SQL to avoid loading unneeded rows.
    query = db.query(ToolResource)
    if category:
        query = query.filter(ToolResource.category == category)
    if enabled is not None:
        query = query.filter(ToolResource.enabled == enabled)
    merged.extend(
        _custom_tool_to_resource(item)
        for item in query.order_by(ToolResource.created_at.desc()).all()
    )

    # Apply the same filters in Python so the built-in entries are filtered too.
    if category:
        merged = [item for item in merged if item.get("category") == category]
    if enabled is not None:
        merged = [item for item in merged if item.get("enabled") == enabled]

    # Normalize pagination. A negative limit would otherwise turn the slice
    # below into "everything except the last |limit| items".
    page = max(page, 1)
    limit = max(limit, 0)
    start = (page - 1) * limit
    return {
        "total": len(merged),
        "page": page,
        "limit": limit,
        "list": merged[start:start + limit],
    }
@router.get("/resources/{id}", response_model=ToolResourceOut)
def get_tool_resource(id: str, db: Session = Depends(get_db)):
    """Return a single tool resource by ID.

    Built-in IDs are answered from TOOL_REGISTRY. Any other ID is looked up
    in the ToolResource table.

    Raises:
        HTTPException(404): no built-in or custom tool has this ID.
    """
    if id not in TOOL_REGISTRY:
        row = db.query(ToolResource).filter(ToolResource.id == id).first()
        if row:
            return row
        raise HTTPException(status_code=404, detail="Tool resource not found")
    return ToolResourceOut(**_builtin_tool_to_resource(id, TOOL_REGISTRY[id]))
@router.post("/resources", response_model=ToolResourceOut)
def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)):
    """Create a custom tool resource.

    If ``data.id`` is provided (whitespace-stripped, non-empty), it is used as
    the tool ID. Otherwise an ID of the form ``tool_<8 hex chars>`` is generated.

    Raises:
        HTTPException(400): the requested ID belongs to a built-in tool, or a
            custom tool with that ID already exists.
    """
    candidate_id = (data.id or "").strip()
    if candidate_id:
        if candidate_id in TOOL_REGISTRY:
            raise HTTPException(status_code=400, detail="Tool ID conflicts with system tool")
        # Reject duplicates up front. Otherwise the insert fails at commit
        # with a database error (surfacing as HTTP 500), assuming ``id`` is
        # the table's primary key.
        existing = db.query(ToolResource).filter(ToolResource.id == candidate_id).first()
        if existing:
            raise HTTPException(status_code=400, detail="Tool ID already exists")
    item = ToolResource(
        id=candidate_id or f"tool_{str(uuid.uuid4())[:8]}",
        # TODO(review): owner is hard-coded; this handler has no auth context.
        user_id=1,
        name=data.name,
        description=data.description,
        category=data.category,
        icon=data.icon,
        enabled=data.enabled,
        is_system=False,
    )
    db.add(item)
    db.commit()
    db.refresh(item)
    return item
@router.put("/resources/{id}", response_model=ToolResourceOut)
def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)):
    """Partially update a custom tool resource.

    Only fields explicitly present in the request body are applied. Identity
    and bookkeeping columns are never taken from the payload. ``updated_at``
    is refreshed with a naive UTC timestamp.

    Raises:
        HTTPException(400): ``id`` is a built-in (read-only) tool.
        HTTPException(404): no custom tool has this ID.
    """
    if id in TOOL_REGISTRY:
        raise HTTPException(status_code=400, detail="System tools are read-only")
    item = db.query(ToolResource).filter(ToolResource.id == id).first()
    if not item:
        raise HTTPException(status_code=404, detail="Tool resource not found")
    # Guard against mass assignment. Should ToolResourceUpdate ever expose
    # these fields, a client still must not be able to change the primary
    # key, reassign ownership, promote a custom tool to "system", or forge
    # timestamps.
    protected = {"id", "user_id", "is_system", "created_at", "updated_at"}
    for field, value in data.model_dump(exclude_unset=True).items():
        if field not in protected:
            setattr(item, field, value)
    item.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(item)
    return item
@router.delete("/resources/{id}")
def delete_tool_resource(id: str, db: Session = Depends(get_db)):
    """Delete a custom tool resource.

    Raises:
        HTTPException(400): ``id`` is a built-in tool, which cannot be deleted.
        HTTPException(404): no custom tool has this ID.
    """
    is_builtin = id in TOOL_REGISTRY
    if is_builtin:
        raise HTTPException(status_code=400, detail="System tools cannot be deleted")
    target = db.query(ToolResource).filter(ToolResource.id == id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Tool resource not found")
    db.delete(target)
    db.commit()
    return {"message": "Deleted successfully"}
# ============ Autotest ============
class AutotestResult:
    """Accumulates the outcome of one autotest run.

    Attributes:
        id: short run identifier (first 8 characters of a random UUID4).
        started_at: wall-clock start time, seconds since the epoch.
        tests: one dict per recorded check (name, passed, message, duration_ms).
        summary: running counts of passed / failed / total checks.
    """

    def __init__(self):
        self.id = str(uuid.uuid4())[:8]
        self.started_at = time.time()
        # Elapsed time is measured on a monotonic clock. Unlike time.time(),
        # it cannot jump when the system clock is adjusted, so the reported
        # duration can never be negative or wildly wrong.
        self._started_monotonic = time.monotonic()
        self.tests = []
        self.summary = {"passed": 0, "failed": 0, "total": 0}

    def add_test(self, name: str, passed: bool, message: str = "", duration_ms: int = 0):
        """Record one check's outcome and update the summary counters."""
        self.tests.append({
            "name": name,
            "passed": passed,
            "message": message,
            "duration_ms": duration_ms,
        })
        if passed:
            self.summary["passed"] += 1
        else:
            self.summary["failed"] += 1
        self.summary["total"] += 1

    def to_dict(self):
        """Return the report as a JSON-serializable dict.

        ``duration_ms`` is the time elapsed since construction, measured at
        the moment this method is called.
        """
        return {
            "id": self.id,
            "started_at": self.started_at,
            "duration_ms": int((time.monotonic() - self._started_monotonic) * 1000),
            "tests": self.tests,
            "summary": self.summary,
        }
@router.post("/autotest")
def run_autotest(
    llm_model_id: Optional[str] = None,
    asr_model_id: Optional[str] = None,
    test_llm: bool = True,
    test_asr: bool = True,
    db: Session = Depends(get_db)
):
    """Run the LLM and/or ASR autotest suites and return the combined report.

    Suites for supplied model IDs run first (LLM, then ASR). Afterwards, a
    failed "... Model Check" entry is appended for every enabled suite whose
    model ID is missing or empty.
    """
    result = AutotestResult()
    suites = (
        (test_llm, llm_model_id, _test_llm_model, "LLM Model Check", "No LLM model ID provided"),
        (test_asr, asr_model_id, _test_asr_model, "ASR Model Check", "No ASR model ID provided"),
    )
    for wanted, model_id, runner, _name, _msg in suites:
        if wanted and model_id:
            runner(db, model_id, result)
    # TTS checks can be added here when needed.
    for wanted, model_id, _runner, check_name, missing_msg in suites:
        if wanted and not model_id:
            result.add_test(check_name, False, missing_msg)
    return result.to_dict()
@router.post("/autotest/llm/{model_id}")
def autotest_llm_model(model_id: str, db: Session = Depends(get_db)):
    """Run the LLM autotest suite against one model and return its report."""
    report = AutotestResult()
    _test_llm_model(db, model_id, report)
    return report.to_dict()
@router.post("/autotest/asr/{model_id}")
def autotest_asr_model(model_id: str, db: Session = Depends(get_db)):
    """Run the ASR autotest suite against one model and return its report."""
    report = AutotestResult()
    _test_asr_model(db, model_id, report)
    return report.to_dict()
def _test_llm_model(db: Session, model_id: str, result: AutotestResult):
    """Run the LLM autotest checks for one model, recording each on *result*.

    Checks, in order:
      1. "Model Existence": the LLMModel row exists (stops early if not).
      2. "API Connection": a non-streaming POST to ``/chat/completions``
         returns a JSON object with a non-empty ``choices`` list.
      3. "Temperature Setting": informational; always recorded as passed.
      4. "Streaming Support": only when ``model.type == "text"``; a streamed
         request must return a 2xx status and at least one body chunk.

    Network and HTTP errors are recorded as failed checks (message truncated
    to 200 characters); nothing is raised to the caller.
    """
    # 1. Model existence
    lookup_start = time.perf_counter()
    model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
    duration_ms = int((time.perf_counter() - lookup_start) * 1000)
    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return
    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # Request settings shared by both network checks. They are built once,
    # outside any try block. Stripping a trailing slash avoids requesting
    # ".../v1//chat/completions".
    endpoint = f"{str(model.base_url).rstrip('/')}/chat/completions"
    headers = {"Authorization": f"Bearer {model.api_key}"}
    model_name = model.model_name or "gpt-3.5-turbo"

    # 2. Non-streaming round trip
    test_start = time.perf_counter()
    try:
        with httpx.Client(timeout=30.0) as client:
            response = client.post(
                endpoint,
                json={
                    "model": model_name,
                    "messages": [{"role": "user", "content": "Reply with 'OK'."}],
                    "max_tokens": 10,
                    "temperature": 0.1,
                },
                headers=headers,
            )
            response.raise_for_status()
        body = response.json()
        latency_ms = int((time.perf_counter() - test_start) * 1000)
        # A non-object JSON body is treated like a body without choices.
        if isinstance(body, dict) and body.get("choices"):
            result.add_test("API Connection", True, f"Latency: {latency_ms}ms", latency_ms)
        else:
            result.add_test("API Connection", False, "Empty response", latency_ms)
    except Exception as e:
        latency_ms = int((time.perf_counter() - test_start) * 1000)
        result.add_test("API Connection", False, str(e)[:200], latency_ms)

    # 3. Model configuration (informational only)
    if model.temperature is not None:
        result.add_test("Temperature Setting", True, f"temperature={model.temperature}")
    else:
        result.add_test("Temperature Setting", True, "Using default")

    # 4. Streaming response (text models only)
    if model.type == "text":
        test_start = time.perf_counter()
        try:
            with httpx.Client(timeout=30.0) as client:
                with client.stream(
                    "POST",
                    endpoint,
                    json={
                        "model": model_name,
                        "messages": [{"role": "user", "content": "Count from 1 to 3."}],
                        "stream": True,
                    },
                    headers=headers,
                ) as response:
                    response.raise_for_status()
                    chunk_count = sum(1 for _ in response.iter_bytes())
            latency_ms = int((time.perf_counter() - test_start) * 1000)
            # A 2xx response with an empty body proves nothing about streaming.
            if chunk_count:
                result.add_test("Streaming Support", True, f"Received {chunk_count} chunks", latency_ms)
            else:
                result.add_test("Streaming Support", False, "Empty stream", latency_ms)
        except Exception as e:
            latency_ms = int((time.perf_counter() - test_start) * 1000)
            result.add_test("Streaming Support", False, str(e)[:200], latency_ms)
def _test_asr_model(db: Session, model_id: str, result: AutotestResult):
    """Run the ASR autotest checks for one model, recording each on *result*.

    Checks, in order:
      1. "Model Existence": the ASRModel row exists (stops early if not).
      2. "Hotwords Config": informational; always recorded as passed.
      3. "API Availability": a GET to a vendor-specific path returns 200 or
         405 (405 is accepted as "endpoint exists, wrong method").
      4. "Language Config": language is one of "zh", "en", "Multi-lingual".

    Network errors are recorded as failed checks; nothing is raised.
    """
    # 1. Model existence
    lookup_start = time.perf_counter()
    model = db.query(ASRModel).filter(ASRModel.id == model_id).first()
    duration_ms = int((time.perf_counter() - lookup_start) * 1000)
    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return
    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # 2. Hotwords configuration
    if model.hotwords:
        # NOTE(review): len() counts items if hotwords is a list, but
        # characters if it is stored as a string; confirm the column type.
        result.add_test("Hotwords Config", True, f"Hotwords: {len(model.hotwords)} words")
    else:
        result.add_test("Hotwords Config", True, "No hotwords configured")

    # 3. API availability. The probe path depends on the vendor; unknown or
    # missing vendors get a generic /health probe instead of crashing on
    # None.lower().
    vendor = (model.vendor or "").lower()
    if vendor in ("siliconflow", "paraformer"):
        probe_path = "/asr"
    elif vendor == "openai":
        probe_path = "/audio/models"
    else:
        probe_path = "/health"
    # Strip a trailing slash so the probe URL never contains "//".
    probe_url = f"{str(model.base_url).rstrip('/')}{probe_path}"
    headers = {"Authorization": f"Bearer {model.api_key}"}

    test_start = time.perf_counter()
    try:
        with httpx.Client(timeout=30.0) as client:
            response = client.get(probe_url, headers=headers)
        latency_ms = int((time.perf_counter() - test_start) * 1000)
        reachable = response.status_code in (200, 405)  # 405 = endpoint exists, method not allowed
        result.add_test("API Availability", reachable, f"Status: {response.status_code}", latency_ms)
    except httpx.TimeoutException:
        latency_ms = int((time.perf_counter() - test_start) * 1000)
        result.add_test("API Availability", False, "Connection timeout", latency_ms)
    except Exception as e:
        latency_ms = int((time.perf_counter() - test_start) * 1000)
        result.add_test("API Availability", False, str(e)[:200], latency_ms)

    # 4. Language configuration
    if model.language in ["zh", "en", "Multi-lingual"]:
        result.add_test("Language Config", True, f"Language: {model.language}")
    else:
        result.add_test("Language Config", False, f"Unknown language: {model.language}")
# ============ Quick Health Check ============
@router.get("/health")
def health_check():
    """Liveness probe: a fixed "healthy" status, the current epoch time and
    the list of built-in tool IDs."""
    status = {"status": "healthy"}
    status["timestamp"] = time.time()
    status["tools"] = [*TOOL_REGISTRY]
    return status
@router.post("/test-message")
def send_test_message(
    llm_model_id: str,
    message: str = "Hello, this is a test message.",
    db: Session = Depends(get_db)
):
    """Send *message* to an LLM model's chat-completions endpoint and return the reply.

    Returns:
        ``{"success": True, "reply": ..., "usage": ...}``. ``reply`` is the
        first choice's message content, or ``""`` when the upstream response
        has no (or empty/null) ``choices`` or no message content.

    Raises:
        HTTPException(404): no LLMModel has ``llm_model_id``.
        HTTPException(500): the upstream call failed (network error, non-2xx
            status, undecodable or unexpectedly shaped body); ``detail``
            carries the error text.
    """
    model = db.query(LLMModel).filter(LLMModel.id == llm_model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="LLM Model not found")
    # Strip a trailing slash so the URL never contains "//chat/completions".
    endpoint = f"{str(model.base_url).rstrip('/')}/chat/completions"
    payload = {
        "model": model.model_name or "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": message}],
        "max_tokens": 500,
        "temperature": 0.7,
    }
    headers = {"Authorization": f"Bearer {model.api_key}"}
    try:
        with httpx.Client(timeout=60.0) as client:
            response = client.post(endpoint, json=payload, headers=headers)
            response.raise_for_status()
        result = response.json()
        # "choices" may be missing, null or an empty list. Treat all of
        # these as "no reply" rather than letting [0] raise IndexError/TypeError.
        choices = result.get("choices") or []
        first_choice = choices[0] if choices else {}
        reply = (first_choice.get("message") or {}).get("content") or ""
        usage = result.get("usage", {})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "success": True,
        "reply": reply,
        "usage": usage,
    }