Files
AI-VideoAssistant/api/app/routers/tools.py
2026-02-11 11:39:45 +08:00

564 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Optional, Dict, Any, List
import time
import uuid
import httpx
from datetime import datetime
from ..db import get_db
from ..models import LLMModel, ASRModel, ToolResource
from ..schemas import ToolResourceCreate, ToolResourceOut, ToolResourceUpdate
# Router for the tool registry, tool-resource CRUD, autotest and health
# endpoints declared below; every route is mounted under the "/tools" prefix.
router = APIRouter(prefix="/tools", tags=["Tools & Autotest"])
# ============ Available Tools ============
# Static catalogue of the built-in tools, keyed by tool ID. Each entry carries
# a display name, a description and a JSON-Schema-style "parameters" object.
TOOL_REGISTRY = {
    # Calculator: takes one required string argument, "expression".
    "calculator": {
        "name": "计算器",
        "description": "执行数学计算",
        "parameters": {
            "type": "object",
            "properties": {
                "expression": {
                    "type": "string",
                    "description": "数学表达式,如: 2 + 3 * 4",
                },
            },
            "required": ["expression"],
        },
    },
    # Code interpreter: takes one required string argument, "code".
    "code_interpreter": {
        "name": "代码执行",
        "description": "安全地执行Python代码",
        "parameters": {
            "type": "object",
            "properties": {
                "code": {
                    "type": "string",
                    "description": "要执行的Python代码",
                },
            },
            "required": ["code"],
        },
    },
    # Current time: no arguments.
    "current_time": {
        "name": "当前时间",
        "description": "获取当前本地时间",
        "parameters": {"type": "object", "properties": {}, "required": []},
    },
    # Turn on camera: no arguments.
    "turn_on_camera": {
        "name": "打开摄像头",
        "description": "执行打开摄像头命令",
        "parameters": {"type": "object", "properties": {}, "required": []},
    },
    # Turn off camera: no arguments.
    "turn_off_camera": {
        "name": "关闭摄像头",
        "description": "执行关闭摄像头命令",
        "parameters": {"type": "object", "properties": {}, "required": []},
    },
    # Increase volume: optional integer "step".
    "increase_volume": {
        "name": "调高音量",
        "description": "提升设备音量",
        "parameters": {
            "type": "object",
            "properties": {
                "step": {"type": "integer", "description": "调整步进默认1"},
            },
            "required": [],
        },
    },
    # Decrease volume: optional integer "step".
    "decrease_volume": {
        "name": "调低音量",
        "description": "降低设备音量",
        "parameters": {
            "type": "object",
            "properties": {
                "step": {"type": "integer", "description": "调整步进默认1"},
            },
            "required": [],
        },
    },
}
# Execution category ("query" or "system") assigned to each built-in tool
# when it is seeded into the tool_resources table.
TOOL_CATEGORY_MAP = {
    **dict.fromkeys(
        ("calculator", "current_time", "code_interpreter"),
        "query",
    ),
    **dict.fromkeys(
        ("turn_on_camera", "turn_off_camera", "increase_volume", "decrease_volume"),
        "system",
    ),
}
# Icon identifier stored with each built-in tool when it is seeded
# (presumably icon names understood by the front end; confirm with the UI code).
TOOL_ICON_MAP = dict(
    calculator="Terminal",
    current_time="Calendar",
    code_interpreter="Terminal",
    turn_on_camera="Camera",
    turn_off_camera="CameraOff",
    increase_volume="Volume2",
    decrease_volume="Volume2",
)
# Default HTTP request settings for built-in tools that have them; tools
# without an entry are seeded with no URL, GET, empty headers and 10000 ms.
TOOL_HTTP_DEFAULTS = {
    "current_time": dict(
        http_method="GET",
        http_url="https://worldtimeapi.org/api/ip",
        http_headers={},
        http_timeout_ms=10000,
    ),
}
def _normalize_http_method(method: Optional[str]) -> str:
normalized = str(method or "GET").strip().upper()
return normalized if normalized in {"GET", "POST", "PUT", "PATCH", "DELETE"} else "GET"
def _requires_http_request(category: str, tool_id: Optional[str]) -> bool:
if category != "query":
return False
return str(tool_id or "").strip() not in {"calculator", "code_interpreter"}
def _validate_query_http_config(*, category: str, tool_id: Optional[str], http_url: Optional[str]) -> None:
    """Reject a tool that needs an HTTP request but has no usable URL.

    Raises:
        HTTPException: 400 when ``_requires_http_request`` is true for the
            tool and *http_url* is missing or blank.
    """
    if not _requires_http_request(category, tool_id):
        return
    cleaned_url = str(http_url or "").strip()
    if cleaned_url:
        return
    raise HTTPException(status_code=400, detail="http_url is required for query tools (except calculator/code_interpreter)")
def _seed_default_tools_if_empty(db: Session) -> None:
    """Insert the built-in tools, but only when ``tool_resources`` has no rows."""
    already_populated = db.query(ToolResource).count() > 0
    if already_populated:
        return

    for builtin_id, spec in TOOL_REGISTRY.items():
        # Tools without HTTP defaults get GET, no URL, no headers and 10000 ms.
        defaults = TOOL_HTTP_DEFAULTS.get(builtin_id, {})
        seeded_row = ToolResource(
            id=builtin_id,
            user_id=1,
            name=spec.get("name", builtin_id),
            description=spec.get("description", ""),
            category=TOOL_CATEGORY_MAP.get(builtin_id, "system"),
            icon=TOOL_ICON_MAP.get(builtin_id, "Wrench"),
            http_method=_normalize_http_method(defaults.get("http_method")),
            http_url=defaults.get("http_url"),
            http_headers=defaults.get("http_headers") or {},
            http_timeout_ms=int(defaults.get("http_timeout_ms") or 10000),
            enabled=True,
            is_system=True,
        )
        db.add(seeded_row)

    # One commit for the whole batch of built-in tools.
    db.commit()
def recreate_tool_resources(db: Session) -> None:
    """Drop and re-create the tool resources table, then seed the built-in tools."""
    engine = db.get_bind()
    table = ToolResource.__table__
    # checkfirst=True makes both steps tolerate a missing / existing table.
    table.drop(bind=engine, checkfirst=True)
    table.create(bind=engine, checkfirst=True)
    _seed_default_tools_if_empty(db)
@router.get("/list")
def list_available_tools():
    """Return the registry of available tools."""
    registry_payload = {"tools": TOOL_REGISTRY}
    return registry_payload
@router.get("/list/{tool_id}")
def get_tool_detail(tool_id: str):
    """Return the registry entry for one tool.

    Raises:
        HTTPException: 404 when *tool_id* is not in ``TOOL_REGISTRY``.
    """
    entry = TOOL_REGISTRY.get(tool_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="Tool not found")
    return entry
# ============ Tool Resource CRUD ============
@router.get("/resources")
def list_tool_resources(
    category: Optional[str] = None,
    enabled: Optional[bool] = None,
    include_system: bool = True,
    page: int = 1,
    limit: int = 100,
    db: Session = Depends(get_db),
):
    """List tool resources, newest first, with optional filters and paging.

    ``system`` / ``query`` only describe how a tool is executed; they do not
    express permissions.

    Args:
        category: When given, keep only resources in this category.
        enabled: When given, keep only resources with this enabled state.
        include_system: When false, leave out resources flagged ``is_system``.
        page: 1-based page number; values below 1 are treated as 1.
        limit: Page size; negative values are treated as 0.
        db: Database session.

    Returns:
        A dict with ``total`` (rows matching the filters), the effective
        ``page`` and ``limit``, and ``list`` (the rows of that page).
    """
    _seed_default_tools_if_empty(db)

    # Clamp paging input: a negative LIMIT/OFFSET is an error on some
    # databases and "unbounded" on others, and the echoed page number should
    # be the page that was actually served.
    page = max(page, 1)
    limit = max(limit, 0)

    query = db.query(ToolResource)
    if not include_system:
        # SQL expression, not a Python truth test, hence the explicit comparison.
        query = query.filter(ToolResource.is_system == False)  # noqa: E712
    if category:
        query = query.filter(ToolResource.category == category)
    if enabled is not None:
        query = query.filter(ToolResource.enabled == enabled)

    # Count before paging so "total" covers every matching row.
    total = query.count()
    rows = (
        query.order_by(ToolResource.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": rows}
@router.get("/resources/{id}", response_model=ToolResourceOut)
def get_tool_resource(id: str, db: Session = Depends(get_db)):
    """Return a single tool resource.

    Raises:
        HTTPException: 404 when no resource has this ``id``.
    """
    _seed_default_tools_if_empty(db)
    resource = db.query(ToolResource).filter(ToolResource.id == id).first()
    if resource:
        return resource
    raise HTTPException(status_code=404, detail="Tool resource not found")
@router.post("/resources", response_model=ToolResourceOut)
def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)):
    """Create a custom (non-system) tool resource.

    Raises:
        HTTPException: 400 when the requested ID is already taken, or when
            the tool needs an HTTP request but no URL was supplied.
    """
    _seed_default_tools_if_empty(db)

    requested_id = (data.id or "").strip()
    if requested_id:
        clash = db.query(ToolResource).filter(ToolResource.id == requested_id).first()
        if clash:
            raise HTTPException(status_code=400, detail="Tool ID already exists")

    _validate_query_http_config(category=data.category, tool_id=requested_id, http_url=data.http_url)

    # Without a caller-supplied ID, derive one from the first 8 characters of a UUID4.
    resource_id = requested_id or f"tool_{str(uuid.uuid4())[:8]}"
    cleaned_url = (data.http_url or "").strip() or None
    timeout_ms = max(1000, int(data.http_timeout_ms or 10000))

    new_resource = ToolResource(
        id=resource_id,
        user_id=1,
        name=data.name,
        description=data.description,
        category=data.category,
        icon=data.icon,
        http_method=_normalize_http_method(data.http_method),
        http_url=cleaned_url,
        http_headers=data.http_headers or {},
        http_timeout_ms=timeout_ms,
        enabled=data.enabled,
        is_system=False,
    )
    db.add(new_resource)
    db.commit()
    db.refresh(new_resource)
    return new_resource
@router.put("/resources/{id}", response_model=ToolResourceOut)
def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)):
    """Update a tool resource with the fields set in the request body.

    Fields left unset in the request are not touched. Submitted values are
    normalised the same way ``create_tool_resource`` does: a blank
    ``http_url`` becomes ``None``, empty ``http_headers`` become ``{}``, an
    unsupported ``http_method`` falls back to ``GET`` and a non-null
    ``http_timeout_ms`` is raised to at least 1000.

    Raises:
        HTTPException: 404 when no resource has this ``id``; 400 when the
            updated tool would need an HTTP request but has no URL.
    """
    _seed_default_tools_if_empty(db)
    item = db.query(ToolResource).filter(ToolResource.id == id).first()
    if not item:
        raise HTTPException(status_code=404, detail="Tool resource not found")

    update_data = data.model_dump(exclude_unset=True)

    # Normalise before validating so that what is checked is what is stored.
    if "http_url" in update_data:
        update_data["http_url"] = str(update_data["http_url"] or "").strip() or None
    if "http_headers" in update_data:
        update_data["http_headers"] = update_data["http_headers"] or {}

    # Validate the state the row will have after the update, not just the delta.
    new_category = update_data.get("category", item.category)
    new_http_url = update_data.get("http_url", item.http_url)
    _validate_query_http_config(category=new_category, tool_id=id, http_url=new_http_url)

    if "http_method" in update_data:
        update_data["http_method"] = _normalize_http_method(update_data.get("http_method"))
    if "http_timeout_ms" in update_data and update_data.get("http_timeout_ms") is not None:
        update_data["http_timeout_ms"] = max(1000, int(update_data["http_timeout_ms"]))

    for field, value in update_data.items():
        setattr(item, field, value)
    # NOTE(review): naive UTC timestamp; kept as-is because the column's
    # timezone handling is not visible from this module.
    item.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(item)
    return item
@router.delete("/resources/{id}")
def delete_tool_resource(id: str, db: Session = Depends(get_db)):
    """Delete a tool resource.

    Raises:
        HTTPException: 404 when no resource has this ``id``.
    """
    _seed_default_tools_if_empty(db)
    doomed = db.query(ToolResource).filter(ToolResource.id == id).first()
    if doomed:
        db.delete(doomed)
        db.commit()
        return {"message": "Deleted successfully"}
    raise HTTPException(status_code=404, detail="Tool resource not found")
# ============ Autotest ============
class AutotestResult:
    """Accumulates the individual checks and pass/fail counts of one autotest run."""

    def __init__(self):
        # Short run identifier: the first 8 characters of a UUID4.
        self.id = str(uuid.uuid4())[:8]
        self.started_at = time.time()
        self.tests = []
        self.summary = dict(passed=0, failed=0, total=0)

    def add_test(self, name: str, passed: bool, message: str = "", duration_ms: int = 0):
        """Record one check and update the summary counters."""
        entry = dict(name=name, passed=passed, message=message, duration_ms=duration_ms)
        self.tests.append(entry)
        counter = "passed" if passed else "failed"
        self.summary[counter] += 1
        self.summary["total"] += 1

    def to_dict(self):
        """Return the run as a plain dict, with the time elapsed since creation in ms."""
        elapsed_ms = int((time.time() - self.started_at) * 1000)
        return dict(
            id=self.id,
            started_at=self.started_at,
            duration_ms=elapsed_ms,
            tests=self.tests,
            summary=self.summary,
        )
@router.post("/autotest")
def run_autotest(
    llm_model_id: Optional[str] = None,
    asr_model_id: Optional[str] = None,
    test_llm: bool = True,
    test_asr: bool = True,
    db: Session = Depends(get_db)
):
    """Run the requested LLM / ASR autotests and return the collected results."""
    report = AutotestResult()

    # Model tests run first, for every kind that was requested with a model ID.
    if test_llm and llm_model_id:
        _test_llm_model(db, llm_model_id, report)
    if test_asr and asr_model_id:
        _test_asr_model(db, asr_model_id, report)

    # TTS checks can be added here when needed.

    # A requested test without a model ID is recorded as a failure, after the
    # real test results.
    missing_id_failures = []
    if test_llm and not llm_model_id:
        missing_id_failures.append(("LLM Model Check", "No LLM model ID provided"))
    if test_asr and not asr_model_id:
        missing_id_failures.append(("ASR Model Check", "No ASR model ID provided"))
    for check_name, reason in missing_id_failures:
        report.add_test(check_name, False, reason)

    return report.to_dict()
@router.post("/autotest/llm/{model_id}")
def autotest_llm_model(model_id: str, db: Session = Depends(get_db)):
    """Run the LLM autotest for a single model and return its results."""
    report = AutotestResult()
    _test_llm_model(db, model_id, report)
    outcome = report.to_dict()
    return outcome
@router.post("/autotest/asr/{model_id}")
def autotest_asr_model(model_id: str, db: Session = Depends(get_db)):
    """Run the ASR autotest for a single model and return its results."""
    report = AutotestResult()
    _test_asr_model(db, model_id, report)
    outcome = report.to_dict()
    return outcome
def _test_llm_model(db: Session, model_id: str, result: AutotestResult):
    """Run the LLM checks for *model_id* and record each outcome on *result*.

    Checks, in order:

    1. "Model Existence": database lookup; nothing else runs when it fails.
    2. "API Connection": a small non-streaming chat completion request.
    3. "Temperature Setting": informational, always recorded as passed.
    4. "Streaming Support": a streaming chat completion, only when
       ``model.type == "text"``.

    HTTP and network failures are recorded as failed tests (message truncated
    to 200 characters); they are not raised to the caller.
    """
    # 1. Check that the model exists.
    start_time = time.time()
    model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
    duration_ms = int((time.time() - start_time) * 1000)
    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return
    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # Request settings shared by checks 2 and 4. They are built outside the
    # try blocks so the streaming check never depends on how far the
    # connection check got. A trailing slash on base_url is dropped so the
    # endpoint does not contain "//".
    base_url = str(model.base_url or "").rstrip("/")
    chat_url = f"{base_url}/chat/completions"
    headers = {"Authorization": f"Bearer {model.api_key}"}
    model_name = model.model_name or "gpt-3.5-turbo"

    # 2. Test the connection with a minimal completion request.
    test_start = time.time()
    try:
        payload = {
            "model": model_name,
            "messages": [{"role": "user", "content": "Reply with 'OK'."}],
            "max_tokens": 10,
            "temperature": 0.1,
        }
        with httpx.Client(timeout=30.0) as client:
            response = client.post(chat_url, json=payload, headers=headers)
        response.raise_for_status()
        body = response.json()
        latency_ms = int((time.time() - test_start) * 1000)
        if body.get("choices"):
            result.add_test("API Connection", True, f"Latency: {latency_ms}ms", latency_ms)
        else:
            result.add_test("API Connection", False, "Empty response", latency_ms)
    except Exception as e:  # best-effort probe: any failure becomes a failed test
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Connection", False, str(e)[:200], latency_ms)

    # 3. Report the model configuration (never fails).
    if model.temperature is not None:
        result.add_test("Temperature Setting", True, f"temperature={model.temperature}")
    else:
        result.add_test("Temperature Setting", True, "Using default")

    # 4. Test the streaming response (text models only).
    if model.type == "text":
        test_start = time.time()
        try:
            stream_payload = {
                "model": model_name,
                "messages": [{"role": "user", "content": "Count from 1 to 3."}],
                "stream": True,
            }
            with httpx.Client(timeout=30.0) as client:
                with client.stream("POST", chat_url, json=stream_payload, headers=headers) as response:
                    response.raise_for_status()
                    # Only the number of raw byte chunks is counted; the
                    # streamed content itself is not inspected.
                    chunk_count = sum(1 for _ in response.iter_bytes())
            latency_ms = int((time.time() - test_start) * 1000)
            result.add_test("Streaming Support", True, f"Received {chunk_count} chunks", latency_ms)
        except Exception as e:  # best-effort probe: any failure becomes a failed test
            latency_ms = int((time.time() - test_start) * 1000)
            result.add_test("Streaming Support", False, str(e)[:200], latency_ms)
def _test_asr_model(db: Session, model_id: str, result: AutotestResult):
    """Run the ASR checks for *model_id* and record each outcome on *result*.

    Checks, in order:

    1. "Model Existence": database lookup; nothing else runs when it fails.
    2. "Hotwords Config": informational, always recorded as passed.
    3. "API Availability": a GET against a vendor-specific path; HTTP 200 or
       405 counts as available.
    4. "Language Config": passes for ``zh``, ``en`` or ``Multi-lingual``.

    Network failures are recorded as failed tests; they are not raised.
    """
    # 1. Check that the model exists.
    start_time = time.time()
    model = db.query(ASRModel).filter(ASRModel.id == model_id).first()
    duration_ms = int((time.time() - start_time) * 1000)
    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return
    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # 2. Report the hotword configuration (never fails).
    if model.hotwords:
        # NOTE(review): len() counts entries only if hotwords is a list;
        # confirm the column type.
        result.add_test("Hotwords Config", True, f"Hotwords: {len(model.hotwords)} words")
    else:
        result.add_test("Hotwords Config", True, "No hotwords configured")

    # 3. Probe API availability.
    # A missing vendor falls through to the generic health check instead of
    # failing the probe with an AttributeError.
    vendor = str(model.vendor or "").lower()
    if vendor in ("siliconflow", "paraformer"):
        probe_path = "/asr"
    elif vendor == "openai":
        probe_path = "/audio/models"
    else:
        # Generic health check.
        probe_path = "/health"
    # Drop a trailing slash on base_url so the probe URL does not contain "//".
    probe_url = f"{str(model.base_url or '').rstrip('/')}{probe_path}"
    headers = {"Authorization": f"Bearer {model.api_key}"}

    test_start = time.time()
    try:
        with httpx.Client(timeout=30.0) as client:
            response = client.get(probe_url, headers=headers)
        latency_ms = int((time.time() - test_start) * 1000)
        # 405 = method not allowed, but the endpoint exists.
        available = response.status_code in (200, 405)
        result.add_test("API Availability", available, f"Status: {response.status_code}", latency_ms)
    except httpx.TimeoutException:
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Availability", False, "Connection timeout", latency_ms)
    except Exception as e:  # best-effort probe: any failure becomes a failed test
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Availability", False, str(e)[:200], latency_ms)

    # 4. Check the language configuration.
    if model.language in ["zh", "en", "Multi-lingual"]:
        result.add_test("Language Config", True, f"Language: {model.language}")
    else:
        result.add_test("Language Config", False, f"Unknown language: {model.language}")
# ============ Quick Health Check ============
@router.get("/health")
def health_check():
    """Quick health check: static status, current timestamp and registered tool IDs."""
    tool_ids = list(TOOL_REGISTRY)
    now = time.time()
    return {"status": "healthy", "timestamp": now, "tools": tool_ids}
@router.post("/test-message")
def send_test_message(
    llm_model_id: str,
    message: str = "Hello, this is a test message.",
    db: Session = Depends(get_db)
):
    """Send one user message to the given LLM model and return its reply.

    Args:
        llm_model_id: ID of the LLM model to call.
        message: Text sent as the single user message.
        db: Database session.

    Returns:
        A dict with ``success``, the ``reply`` text (empty when the response
        carries no choices) and the ``usage`` object of the response.

    Raises:
        HTTPException: 404 when the model does not exist; 500 when the
            request fails or the response cannot be processed.
    """
    model = db.query(LLMModel).filter(LLMModel.id == llm_model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="LLM Model not found")

    try:
        payload = {
            "model": model.model_name or "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": message}],
            "max_tokens": 500,
            "temperature": 0.7,
        }
        headers = {"Authorization": f"Bearer {model.api_key}"}
        # Drop a trailing slash on base_url so the endpoint does not contain "//".
        chat_url = f"{str(model.base_url or '').rstrip('/')}/chat/completions"
        with httpx.Client(timeout=60.0) as client:
            response = client.post(chat_url, json=payload, headers=headers)
        response.raise_for_status()
        result = response.json()

        # Tolerate a missing, null or empty "choices" list and a null
        # "message" object instead of failing with IndexError/AttributeError.
        choices = result.get("choices") or [{}]
        first_choice = choices[0] or {}
        reply = (first_choice.get("message") or {}).get("content", "")
        return {
            "success": True,
            "reply": reply,
            "usage": result.get("usage", {})
        }
    except Exception as e:
        # Every failure is surfaced to the client as a 500 carrying the error text.
        raise HTTPException(status_code=500, detail=str(e)) from e