Files
AI-VideoAssistant/api/app/routers/tools.py

743 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from sqlalchemy import inspect, text
from typing import Optional, Dict, Any, List
import time
import uuid
import httpx
from datetime import datetime
from ..db import get_db
from ..models import LLMModel, ASRModel, ToolResource
from ..schemas import ToolResourceCreate, ToolResourceOut, ToolResourceUpdate
# Router for every endpoint in this module; all routes are mounted under the
# "/tools" prefix and tagged "Tools & Autotest".
router = APIRouter(prefix="/tools", tags=["Tools & Autotest"])
# ============ Available Tools ============
def _object_schema(properties=None, required=()):
    """Build a fresh JSON-Schema ``object`` node describing a tool's parameters."""
    return {
        "type": "object",
        "properties": dict(properties or {}),
        "required": list(required),
    }


def _choice_options_schema():
    """Build a fresh schema for the ``options`` array used by the choice prompts."""
    return {
        "type": "array",
        "description": "可选项(字符串或含 id/label/value 的对象)",
        "minItems": 2,
        "items": {
            "anyOf": [
                {"type": "string"},
                {
                    "type": "object",
                    "properties": {
                        "id": {"type": "string"},
                        "label": {"type": "string"},
                        "value": {"type": "string"},
                    },
                    "required": ["label"],
                },
            ]
        },
    }


# Built-in tool definitions keyed by tool id. Each entry carries a display
# name, a description and a JSON-Schema style "parameters" object. Every entry
# gets its own freshly built schema dicts, so no nested object is shared.
TOOL_REGISTRY = {
    "calculator": {
        "name": "计算器",
        "description": "执行数学计算",
        "parameters": _object_schema(
            {"expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}},
            required=["expression"],
        ),
    },
    "code_interpreter": {
        "name": "代码执行",
        "description": "安全地执行Python代码",
        "parameters": _object_schema(
            {"code": {"type": "string", "description": "要执行的Python代码"}},
            required=["code"],
        ),
    },
    "current_time": {
        "name": "当前时间",
        "description": "获取当前本地时间",
        "parameters": _object_schema(),
    },
    "turn_on_camera": {
        "name": "打开摄像头",
        "description": "执行打开摄像头命令",
        "parameters": _object_schema(),
    },
    "turn_off_camera": {
        "name": "关闭摄像头",
        "description": "执行关闭摄像头命令",
        "parameters": _object_schema(),
    },
    "increase_volume": {
        "name": "调高音量",
        "description": "提升设备音量",
        "parameters": _object_schema(
            {"step": {"type": "integer", "description": "调整步进默认1"}},
        ),
    },
    "decrease_volume": {
        "name": "调低音量",
        "description": "降低设备音量",
        "parameters": _object_schema(
            {"step": {"type": "integer", "description": "调整步进默认1"}},
        ),
    },
    "voice_message_prompt": {
        "name": "语音消息提示",
        "description": "播报一条语音提示消息",
        "parameters": _object_schema(
            {"msg": {"type": "string", "description": "要播报的消息文本"}},
            required=["msg"],
        ),
    },
    "text_msg_prompt": {
        "name": "文本消息提示",
        "description": "显示一条文本弹窗提示",
        "parameters": _object_schema(
            {"msg": {"type": "string", "description": "提示文本内容"}},
            required=["msg"],
        ),
    },
    "voice_choice_prompt": {
        "name": "语音选项提示",
        "description": "播报问题并展示可选项,等待用户选择后回传结果",
        "parameters": _object_schema(
            {
                "question": {"type": "string", "description": "向用户展示的问题文本"},
                "options": _choice_options_schema(),
                "voice_text": {"type": "string", "description": "可选,单独指定播报文本;为空则播报 question"},
            },
            required=["question", "options"],
        ),
    },
    "text_choice_prompt": {
        "name": "文本选项提示",
        "description": "显示文本选项弹窗并等待用户选择后回传结果",
        "parameters": _object_schema(
            {
                "question": {"type": "string", "description": "向用户展示的问题文本"},
                "options": _choice_options_schema(),
            },
            required=["question", "options"],
        ),
    },
}
# Execution category ("query" or "system") for each built-in tool id.
TOOL_CATEGORY_MAP = {
    **dict.fromkeys(
        ("calculator", "current_time", "code_interpreter"),
        "query",
    ),
    **dict.fromkeys(
        (
            "turn_on_camera",
            "turn_off_camera",
            "increase_volume",
            "decrease_volume",
            "voice_message_prompt",
            "text_msg_prompt",
            "voice_choice_prompt",
            "text_choice_prompt",
        ),
        "system",
    ),
}
# Icon name stored on the seeded ToolResource row for each built-in tool id.
TOOL_ICON_MAP = dict(
    calculator="Terminal",
    current_time="Calendar",
    code_interpreter="Terminal",
    turn_on_camera="Camera",
    turn_off_camera="CameraOff",
    increase_volume="Volume2",
    decrease_volume="Volume2",
    voice_message_prompt="Volume2",
    text_msg_prompt="Terminal",
    voice_choice_prompt="Volume2",
    text_choice_prompt="Terminal",
)
# Per-tool HTTP settings applied when seeding built-in tools (currently none).
TOOL_HTTP_DEFAULTS = {}

# Default argument values stored with the seeded built-in tools.
TOOL_PARAMETER_DEFAULTS = {
    tool_id: {"step": 1}
    for tool_id in ("increase_volume", "decrease_volume")
}

# Built-in tools whose seeded row has wait_for_response set to True.
TOOL_WAIT_FOR_RESPONSE_DEFAULTS = dict.fromkeys(
    ("text_msg_prompt", "voice_choice_prompt", "text_choice_prompt"),
    True,
)
def _normalize_parameter_schema(value: Any, *, tool_id: Optional[str] = None) -> Dict[str, Any]:
if not isinstance(value, dict):
value = {}
normalized = dict(value)
if not normalized:
fallback = TOOL_REGISTRY.get(str(tool_id or "").strip(), {}).get("parameters")
if isinstance(fallback, dict):
normalized = dict(fallback)
normalized.setdefault("type", "object")
if normalized.get("type") != "object":
raise HTTPException(status_code=400, detail="parameter_schema.type must be 'object'")
properties = normalized.get("properties")
if not isinstance(properties, dict):
normalized["properties"] = {}
required = normalized.get("required")
if required is None:
normalized["required"] = []
elif not isinstance(required, list):
raise HTTPException(status_code=400, detail="parameter_schema.required must be an array")
return normalized
def _normalize_parameter_defaults(value: Any) -> Dict[str, Any]:
if value is None:
return {}
if not isinstance(value, dict):
raise HTTPException(status_code=400, detail="parameter_defaults must be an object")
return dict(value)
def _ensure_tool_resource_schema(db: Session) -> None:
    """Add any missing newer columns to the ``tool_resources`` table.

    Applies lightweight SQLite-style ``ALTER TABLE ... ADD COLUMN`` migrations
    and commits once if at least one column was added. Returns silently when
    the table's columns cannot be inspected.
    """
    inspector = inspect(db.get_bind())
    try:
        present = {column["name"] for column in inspector.get_columns("tool_resources")}
    except Exception:
        # Best effort: inspection failed (presumably the table does not exist
        # yet), so there is nothing to migrate.
        return

    migrations = (
        ("parameter_schema", "ALTER TABLE tool_resources ADD COLUMN parameter_schema JSON"),
        ("parameter_defaults", "ALTER TABLE tool_resources ADD COLUMN parameter_defaults JSON"),
        ("wait_for_response", "ALTER TABLE tool_resources ADD COLUMN wait_for_response BOOLEAN DEFAULT 0"),
    )
    pending = [ddl for column, ddl in migrations if column not in present]
    for ddl in pending:
        db.execute(text(ddl))
    if pending:
        db.commit()
def _normalize_http_method(method: Optional[str]) -> str:
normalized = str(method or "GET").strip().upper()
return normalized if normalized in {"GET", "POST", "PUT", "PATCH", "DELETE"} else "GET"
def _requires_http_request(category: str, tool_id: Optional[str]) -> bool:
if category != "query":
return False
return str(tool_id or "").strip() not in {"calculator", "code_interpreter", "current_time"}
def _validate_query_http_config(*, category: str, tool_id: Optional[str], http_url: Optional[str]) -> None:
    """Reject tools that need an HTTP request but have no URL configured.

    Args:
        category: Tool category being created or updated.
        tool_id: Tool identifier; may be empty for a tool without an id yet.
        http_url: Configured URL; ``None`` and blank strings count as missing.

    Raises:
        HTTPException: 400 when ``_requires_http_request`` is true for the
            tool and ``http_url`` is blank.
    """
    if not _requires_http_request(category, tool_id):
        return
    if str(http_url or "").strip():
        return
    # The exemption list in the message mirrors the built-in ids that
    # _requires_http_request skips (current_time was previously omitted).
    raise HTTPException(
        status_code=400,
        detail="http_url is required for query tools (except calculator/code_interpreter/current_time)",
    )
def _seed_default_tools_if_empty(db: Session) -> None:
    """Insert every ``TOOL_REGISTRY`` tool that has no ``tool_resources`` row.

    Rows whose id already exists are left untouched. The table schema is
    migrated first, and a single commit is issued only if something was added.
    """
    _ensure_tool_resource_schema(db)

    known_ids = {str(row[0]) for row in db.query(ToolResource.id).all()}
    missing = [
        (tool_id, spec)
        for tool_id, spec in TOOL_REGISTRY.items()
        if tool_id not in known_ids
    ]

    for tool_id, spec in missing:
        defaults = TOOL_HTTP_DEFAULTS.get(tool_id, {})
        seeded = ToolResource(
            id=tool_id,
            user_id=1,
            name=spec.get("name", tool_id),
            description=spec.get("description", ""),
            category=TOOL_CATEGORY_MAP.get(tool_id, "system"),
            icon=TOOL_ICON_MAP.get(tool_id, "Wrench"),
            http_method=_normalize_http_method(defaults.get("http_method")),
            http_url=defaults.get("http_url"),
            http_headers=defaults.get("http_headers") or {},
            http_timeout_ms=int(defaults.get("http_timeout_ms") or 10000),
            parameter_schema=_normalize_parameter_schema(spec.get("parameters"), tool_id=tool_id),
            parameter_defaults=_normalize_parameter_defaults(TOOL_PARAMETER_DEFAULTS.get(tool_id)),
            wait_for_response=bool(TOOL_WAIT_FOR_RESPONSE_DEFAULTS.get(tool_id, False)),
            enabled=True,
            is_system=True,
        )
        db.add(seeded)

    if missing:
        db.commit()
def recreate_tool_resources(db: Session) -> None:
    """Drop and re-create the tool resources table, then seed the built-ins.

    All existing rows are lost because the table itself is dropped.
    """
    engine = db.get_bind()
    table = ToolResource.__table__
    table.drop(bind=engine, checkfirst=True)
    table.create(bind=engine, checkfirst=True)
    _seed_default_tools_if_empty(db)
@router.get("/list")
def list_available_tools():
    """Return the registry of built-in tools under the ``tools`` key."""
    return dict(tools=TOOL_REGISTRY)
@router.get("/list/{tool_id}")
def get_tool_detail(tool_id: str):
    """Return the registry entry of one built-in tool.

    Raises:
        HTTPException: 404 when ``tool_id`` is not in ``TOOL_REGISTRY``.
    """
    entry = TOOL_REGISTRY.get(tool_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="Tool not found")
    return entry
# ============ Tool Resource CRUD ============
@router.get("/resources")
def list_tool_resources(
    category: Optional[str] = None,
    enabled: Optional[bool] = None,
    include_system: bool = True,
    page: int = 1,
    limit: int = 100,
    db: Session = Depends(get_db),
):
    """List tool resources, newest first, with optional filters and paging.

    "system" / "query" only describe how a tool is executed; they are not a
    permission level.

    Args:
        category: Only return resources of this category when given.
        enabled: Only return resources with this enabled flag when given.
        include_system: When False, rows with ``is_system`` set are excluded.
        page: 1-based page number; values below 1 are treated as 1.
        limit: Page size; negative values are treated as 0.
        db: Database session.

    Returns:
        A dict with ``total`` (rows matching the filters), the effective
        ``page`` and ``limit``, and ``list`` (the rows of that page).
    """
    _seed_default_tools_if_empty(db)

    # Clamp paging inputs so the echoed values match the offset/limit that are
    # actually applied, and a negative LIMIT is never sent to the database.
    page = max(page, 1)
    limit = max(limit, 0)

    query = db.query(ToolResource)
    if not include_system:
        # SQLAlchemy column comparison: "== False" builds a SQL expression.
        query = query.filter(ToolResource.is_system == False)  # noqa: E712
    if category:
        query = query.filter(ToolResource.category == category)
    if enabled is not None:
        query = query.filter(ToolResource.enabled == enabled)

    total = query.count()
    rows = (
        query.order_by(ToolResource.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": rows}
@router.get("/resources/{id}", response_model=ToolResourceOut)
def get_tool_resource(id: str, db: Session = Depends(get_db)):
    """Return a single tool resource by id.

    Raises:
        HTTPException: 404 when no resource has this id.
    """
    _seed_default_tools_if_empty(db)
    lookup = db.query(ToolResource).filter(ToolResource.id == id)
    resource = lookup.first()
    if resource:
        return resource
    raise HTTPException(status_code=404, detail="Tool resource not found")
@router.post("/resources", response_model=ToolResourceOut)
def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)):
    """Create a custom (non-system) tool resource.

    A random ``tool_<8 hex chars>`` id is generated when the payload carries
    no id. ``wait_for_response`` is only honoured for "system" tools.

    Raises:
        HTTPException: 400 when the id already exists, or when any of the
            validation/normalization helpers rejects the payload.
    """
    _seed_default_tools_if_empty(db)

    requested_id = (data.id or "").strip()
    if requested_id:
        duplicate = db.query(ToolResource).filter(ToolResource.id == requested_id).first()
        if duplicate:
            raise HTTPException(status_code=400, detail="Tool ID already exists")

    _validate_query_http_config(category=data.category, tool_id=requested_id, http_url=data.http_url)
    schema = _normalize_parameter_schema(data.parameter_schema, tool_id=requested_id)
    defaults = _normalize_parameter_defaults(data.parameter_defaults)

    # The first 8 characters of a UUID4 are the same in str() and .hex form.
    resource_id = requested_id or f"tool_{uuid.uuid4().hex[:8]}"
    timeout_ms = max(1000, int(data.http_timeout_ms or 10000))
    waits = data.category == "system" and bool(data.wait_for_response)

    resource = ToolResource(
        id=resource_id,
        user_id=1,
        name=data.name,
        description=data.description,
        category=data.category,
        icon=data.icon,
        http_method=_normalize_http_method(data.http_method),
        http_url=(data.http_url or "").strip() or None,
        http_headers=data.http_headers or {},
        http_timeout_ms=timeout_ms,
        parameter_schema=schema,
        parameter_defaults=defaults,
        wait_for_response=waits,
        enabled=data.enabled,
        is_system=False,
    )
    db.add(resource)
    db.commit()
    db.refresh(resource)
    return resource
@router.put("/resources/{id}", response_model=ToolResourceOut)
def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)):
    """Update a tool resource with the fields explicitly set in ``data``.

    Submitted values are normalized the same way as on creation: the HTTP URL
    is stripped (blank becomes ``None``), the HTTP method is validated, the
    timeout is at least 1000 ms and the parameter schema/defaults are checked.
    ``wait_for_response`` is forced to False when the resulting category is
    not "system".

    Raises:
        HTTPException: 404 when the resource does not exist; 400 when the
            resulting configuration is rejected by the validation helpers.
    """
    _seed_default_tools_if_empty(db)
    item = db.query(ToolResource).filter(ToolResource.id == id).first()
    if not item:
        raise HTTPException(status_code=404, detail="Tool resource not found")

    update_data = data.model_dump(exclude_unset=True)

    # Keep stored URLs consistent with create_tool_resource: never persist
    # surrounding whitespace or a blank string.
    if "http_url" in update_data:
        update_data["http_url"] = (update_data["http_url"] or "").strip() or None

    # Validate against the state the row will have after this update.
    new_category = update_data.get("category", item.category)
    new_http_url = update_data.get("http_url", item.http_url)
    _validate_query_http_config(category=new_category, tool_id=id, http_url=new_http_url)

    if "http_method" in update_data:
        update_data["http_method"] = _normalize_http_method(update_data.get("http_method"))
    if "http_timeout_ms" in update_data and update_data.get("http_timeout_ms") is not None:
        update_data["http_timeout_ms"] = max(1000, int(update_data["http_timeout_ms"]))
    if "parameter_schema" in update_data:
        update_data["parameter_schema"] = _normalize_parameter_schema(update_data.get("parameter_schema"), tool_id=id)
    if "parameter_defaults" in update_data:
        update_data["parameter_defaults"] = _normalize_parameter_defaults(update_data.get("parameter_defaults"))
    if new_category != "system":
        update_data["wait_for_response"] = False

    for field, value in update_data.items():
        setattr(item, field, value)
    # NOTE(review): naive UTC timestamp kept as in the original;
    # datetime.utcnow() is deprecated from Python 3.12 on. Confirm how the
    # column is stored before switching to an aware datetime.
    item.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(item)
    return item
@router.delete("/resources/{id}")
def delete_tool_resource(id: str, db: Session = Depends(get_db)):
    """Delete a tool resource by id.

    Raises:
        HTTPException: 404 when no resource has this id.
    """
    # NOTE: a deleted built-in tool is inserted again by the next call to
    # _seed_default_tools_if_empty, which adds every TOOL_REGISTRY id that is
    # missing from the table.
    _seed_default_tools_if_empty(db)
    target = db.query(ToolResource).filter(ToolResource.id == id).first()
    if target:
        db.delete(target)
        db.commit()
        return {"message": "Deleted successfully"}
    raise HTTPException(status_code=404, detail="Tool resource not found")
# ============ Autotest ============
class AutotestResult:
    """Collects the individual checks and pass/fail counters of one autotest run."""

    def __init__(self):
        # Short run id: the first 8 hex characters of a random UUID4.
        self.id = uuid.uuid4().hex[:8]
        self.started_at = time.time()
        self.tests = []
        self.summary = {"passed": 0, "failed": 0, "total": 0}

    def add_test(self, name: str, passed: bool, message: str = "", duration_ms: int = 0):
        """Record one check and update the summary counters."""
        entry = dict(name=name, passed=passed, message=message, duration_ms=duration_ms)
        self.tests.append(entry)
        bucket = "passed" if passed else "failed"
        self.summary[bucket] += 1
        self.summary["total"] += 1

    def to_dict(self):
        """Return the run as a plain dict; ``duration_ms`` is measured at call time."""
        elapsed_ms = int((time.time() - self.started_at) * 1000)
        return {
            "id": self.id,
            "started_at": self.started_at,
            "duration_ms": elapsed_ms,
            "tests": self.tests,
            "summary": self.summary,
        }
@router.post("/autotest")
def run_autotest(
    llm_model_id: Optional[str] = None,
    asr_model_id: Optional[str] = None,
    test_llm: bool = True,
    test_asr: bool = True,
    db: Session = Depends(get_db)
):
    """Run the requested LLM and/or ASR checks and return the combined report.

    A requested check without a model id is recorded as a failed
    "... Model Check" entry; those entries come after the results of the
    checks that actually ran.
    """
    report = AutotestResult()
    missing_llm_id = test_llm and not llm_model_id
    missing_asr_id = test_asr and not asr_model_id

    # Checks that can run go first (LLM, then ASR).
    if test_llm and not missing_llm_id:
        _test_llm_model(db, llm_model_id, report)
    if test_asr and not missing_asr_id:
        _test_asr_model(db, asr_model_id, report)

    # Then the requested-but-unconfigured checks are reported as failures.
    if missing_llm_id:
        report.add_test("LLM Model Check", False, "No LLM model ID provided")
    if missing_asr_id:
        report.add_test("ASR Model Check", False, "No ASR model ID provided")

    return report.to_dict()
@router.post("/autotest/llm/{model_id}")
def autotest_llm_model(model_id: str, db: Session = Depends(get_db)):
    """Run the LLM checks for a single model and return the report dict."""
    report = AutotestResult()
    _test_llm_model(db, model_id, report)
    payload = report.to_dict()
    return payload
@router.post("/autotest/asr/{model_id}")
def autotest_asr_model(model_id: str, db: Session = Depends(get_db)):
    """Run the ASR checks for a single model and return the report dict."""
    report = AutotestResult()
    _test_asr_model(db, model_id, report)
    payload = report.to_dict()
    return payload
def _test_llm_model(db: Session, model_id: str, result: AutotestResult):
    """Run the LLM checks for ``model_id`` and record each outcome on ``result``.

    Checks, in order: the model row exists, a non-streaming
    ``/chat/completions`` request succeeds, the temperature setting (always
    recorded as passed), and, for models of type ``"text"``, a streaming
    ``/chat/completions`` request. Request failures are recorded as failed
    tests rather than raised.

    Args:
        db: Database session used to load the model row.
        model_id: Id of the LLM model to test.
        result: Collector that receives one entry per check.
    """
    # 1. Check that the model exists.
    start_time = time.time()
    model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
    duration_ms = int((time.time() - start_time) * 1000)
    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return
    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # Request settings shared by both HTTP checks. They are built before any
    # try block so the streaming check never depends on how far the first
    # request got. A trailing slash on the base URL is dropped so the endpoint
    # does not become ".../v1//chat/completions".
    base_url = str(model.base_url or "").rstrip("/")
    endpoint = f"{base_url}/chat/completions"
    headers = {"Authorization": f"Bearer {model.api_key}"}
    model_name = model.model_name or "gpt-3.5-turbo"

    # 2. Test the connection with a minimal completion request.
    test_start = time.time()
    try:
        payload = {
            "model": model_name,
            "messages": [{"role": "user", "content": "Reply with 'OK'."}],
            "max_tokens": 10,
            "temperature": 0.1,
        }
        with httpx.Client(timeout=30.0) as client:
            response = client.post(endpoint, json=payload, headers=headers)
            response.raise_for_status()
            body = response.json()
        latency_ms = int((time.time() - test_start) * 1000)
        if body.get("choices"):
            result.add_test("API Connection", True, f"Latency: {latency_ms}ms", latency_ms)
        else:
            result.add_test("API Connection", False, "Empty response", latency_ms)
    except Exception as e:
        # Any failure (network, HTTP status, malformed body) is reported as a
        # failed check; the message is truncated to keep the report small.
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Connection", False, str(e)[:200], latency_ms)

    # 3. Report the model configuration (informational, always passes).
    if model.temperature is not None:
        result.add_test("Temperature Setting", True, f"temperature={model.temperature}")
    else:
        result.add_test("Temperature Setting", True, "Using default")

    # 4. Test streaming responses (text models only).
    if model.type == "text":
        test_start = time.time()
        try:
            stream_payload = {
                "model": model_name,
                "messages": [{"role": "user", "content": "Count from 1 to 3."}],
                "stream": True,
            }
            with httpx.Client(timeout=30.0) as client:
                with client.stream("POST", endpoint, json=stream_payload, headers=headers) as response:
                    response.raise_for_status()
                    # Counts raw byte chunks received, not parsed SSE events.
                    chunk_count = sum(1 for _ in response.iter_bytes())
            latency_ms = int((time.time() - test_start) * 1000)
            result.add_test("Streaming Support", True, f"Received {chunk_count} chunks", latency_ms)
        except Exception as e:
            latency_ms = int((time.time() - test_start) * 1000)
            result.add_test("Streaming Support", False, str(e)[:200], latency_ms)
def _test_asr_model(db: Session, model_id: str, result: AutotestResult):
    """Run the ASR checks for ``model_id`` and record each outcome on ``result``.

    Checks, in order: the model row exists, the hotword configuration
    (informational, always passes), a GET probe against a vendor-specific
    endpoint, and the configured language. Probe failures are recorded as
    failed tests rather than raised.

    Args:
        db: Database session used to load the model row.
        model_id: Id of the ASR model to test.
        result: Collector that receives one entry per check.
    """
    # 1. Check that the model exists.
    start_time = time.time()
    model = db.query(ASRModel).filter(ASRModel.id == model_id).first()
    duration_ms = int((time.time() - start_time) * 1000)
    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return
    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # 2. Report the hotword configuration.
    if model.hotwords:
        result.add_test("Hotwords Config", True, f"Hotwords: {len(model.hotwords)} words")
    else:
        result.add_test("Hotwords Config", True, "No hotwords configured")

    # 3. Probe API availability.
    test_start = time.time()
    try:
        headers = {"Authorization": f"Bearer {model.api_key}"}
        # Drop a trailing slash so the probe URL never contains "//".
        base_url = str(model.base_url or "").rstrip("/")
        # Every branch uses the normalized vendor, so a missing vendor falls
        # through to the generic health probe instead of raising on None.
        normalized_vendor = (model.vendor or "").strip().lower()
        if normalized_vendor in (
            "openai compatible",
            "openai-compatible",
            "siliconflow",  # backward compatibility
            "paraformer",
        ):
            probe_path = "/asr"
        elif normalized_vendor == "openai":
            probe_path = "/audio/models"
        else:
            # Generic health check.
            probe_path = "/health"

        with httpx.Client(timeout=30.0) as client:
            response = client.get(f"{base_url}{probe_path}", headers=headers)
        latency_ms = int((time.time() - test_start) * 1000)
        if response.status_code in [200, 405]:  # 405 = method not allowed but endpoint exists
            result.add_test("API Availability", True, f"Status: {response.status_code}", latency_ms)
        else:
            result.add_test("API Availability", False, f"Status: {response.status_code}", latency_ms)
    except httpx.TimeoutException:
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Availability", False, "Connection timeout", latency_ms)
    except Exception as e:
        latency_ms = int((time.time() - test_start) * 1000)
        result.add_test("API Availability", False, str(e)[:200], latency_ms)

    # 4. Check the language configuration.
    if model.language in ["zh", "en", "Multi-lingual"]:
        result.add_test("Language Config", True, f"Language: {model.language}")
    else:
        result.add_test("Language Config", False, f"Unknown language: {model.language}")
# ============ Quick Health Check ============
@router.get("/health")
def health_check():
    """Quick health check: a static status, the current time and the built-in tool ids."""
    tool_ids = list(TOOL_REGISTRY)
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "tools": tool_ids,
    }
@router.post("/test-message")
def send_test_message(
    llm_model_id: str,
    message: str = "Hello, this is a test message.",
    db: Session = Depends(get_db)
):
    """Send one user message to the given LLM model and return its reply.

    Args:
        llm_model_id: Id of the LLM model to call.
        message: User message content to send.
        db: Database session used to load the model row.

    Returns:
        ``{"success": True, "reply": <text>, "usage": <dict>}``. ``reply`` is
        an empty string when the response carries no choices or no content.

    Raises:
        HTTPException: 404 when the model does not exist; 500 with the
            underlying error text when the request or response handling fails.
    """
    model = db.query(LLMModel).filter(LLMModel.id == llm_model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="LLM Model not found")
    try:
        payload = {
            "model": model.model_name or "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": message}],
            "max_tokens": 500,
            "temperature": 0.7,
        }
        headers = {"Authorization": f"Bearer {model.api_key}"}
        # Drop a trailing slash so the endpoint never contains "//".
        base_url = str(model.base_url or "").rstrip("/")
        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                f"{base_url}/chat/completions",
                json=payload,
                headers=headers
            )
            response.raise_for_status()
            result = response.json()
        # Treat a missing, null or empty "choices" the same way, and tolerate
        # a null "message"/"content", instead of failing with an
        # IndexError/AttributeError on a response that was otherwise valid.
        choices = result.get("choices") or []
        first_choice = choices[0] if choices else {}
        reply = (first_choice.get("message") or {}).get("content") or ""
        return {
            "success": True,
            "reply": reply,
            "usage": result.get("usage", {})
        }
    except Exception as e:
        # Chain the cause so the original traceback is kept for logging.
        raise HTTPException(status_code=500, detail=str(e)) from e