"""Tool registry, tool-resource CRUD and model autotest endpoints."""

import time
import uuid
from datetime import datetime
from typing import Optional, Dict, Any, List

import httpx
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from ..db import get_db
from ..models import LLMModel, ASRModel, ToolResource
from ..schemas import ToolResourceCreate, ToolResourceOut, ToolResourceUpdate

router = APIRouter(prefix="/tools", tags=["Tools & Autotest"])

# ============ Available Tools ============

# Built-in tool definitions keyed by tool id. Each entry carries a display
# name, a description and a JSON-Schema-style "parameters" object.
TOOL_REGISTRY = {
    "search": {
        "name": "网络搜索",
        "description": "搜索互联网获取最新信息",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "搜索关键词"}
            },
            "required": ["query"]
        }
    },
    "calculator": {
        "name": "计算器",
        "description": "执行数学计算",
        "parameters": {
            "type": "object",
            "properties": {
                "expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}
            },
            "required": ["expression"]
        }
    },
    "weather": {
        "name": "天气查询",
        "description": "查询指定城市的天气",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "城市名称"}
            },
            "required": ["city"]
        }
    },
    "translate": {
        "name": "翻译",
        "description": "翻译文本到指定语言",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {"type": "string", "description": "要翻译的文本"},
                "target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"}
            },
            "required": ["text", "target_lang"]
        }
    },
    "knowledge": {
        "name": "知识库查询",
        "description": "从知识库中检索相关信息",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "查询内容"},
                "kb_id": {"type": "string", "description": "知识库ID"}
            },
            "required": ["query"]
        }
    },
    "code_interpreter": {
        "name": "代码执行",
        "description": "安全地执行Python代码",
        "parameters": {
            "type": "object",
            "properties": {
                "code": {"type": "string", "description": "要执行的Python代码"}
            },
            "required": ["code"]
        }
    },
}

# Category stored on the seeded ToolResource row for each built-in tool.
TOOL_CATEGORY_MAP = {
    "search": "query",
    "weather": "query",
    "translate": "query",
    "knowledge": "query",
    "calculator": "system",
    "code_interpreter": "system",
}

# Icon name stored on the seeded ToolResource row for each built-in tool.
TOOL_ICON_MAP = {
    "search": "Globe",
    "weather": "CloudSun",
    "translate": "Globe",
    "knowledge": "Box",
    "calculator": "Terminal",
    "code_interpreter": "Terminal",
}

# Model name sent to the provider when the LLM row has no model_name.
_DEFAULT_LLM_MODEL_NAME = "gpt-3.5-turbo"

# Language values accepted by the ASR "Language Config" check.
_KNOWN_ASR_LANGUAGES = ("zh", "en", "Multi-lingual")


def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since *start* (a perf_counter value)."""
    return int((time.perf_counter() - start) * 1000)


def _join_url(base_url: Optional[str], path: str) -> str:
    """Join *base_url* and *path* with exactly one slash between them.

    A missing base URL yields just ``/path``; the resulting request error is
    left to the caller's exception handling.
    """
    return f"{str(base_url or '').rstrip('/')}/{path.lstrip('/')}"


def _seed_default_tools_if_empty(db: Session) -> None:
    """Seed default tools into DB when tool_resources is empty."""
    if db.query(ToolResource).count() > 0:
        return
    for tool_id, payload in TOOL_REGISTRY.items():
        db.add(ToolResource(
            id=tool_id,
            user_id=1,
            name=payload.get("name", tool_id),
            description=payload.get("description", ""),
            category=TOOL_CATEGORY_MAP.get(tool_id, "system"),
            icon=TOOL_ICON_MAP.get(tool_id, "Wrench"),
            enabled=True,
            is_system=True,
        ))
    db.commit()


def _get_tool_resource_or_404(db: Session, resource_id: str) -> ToolResource:
    """Seed defaults if needed, then return the tool resource or raise 404."""
    _seed_default_tools_if_empty(db)
    item = db.query(ToolResource).filter(ToolResource.id == resource_id).first()
    if not item:
        raise HTTPException(status_code=404, detail="Tool resource not found")
    return item


@router.get("/list")
def list_available_tools():
    """Return the full built-in tool registry."""
    return {"tools": TOOL_REGISTRY}


@router.get("/list/{tool_id}")
def get_tool_detail(tool_id: str):
    """Return one built-in tool definition, or 404 if the id is unknown."""
    if tool_id not in TOOL_REGISTRY:
        raise HTTPException(status_code=404, detail="Tool not found")
    return TOOL_REGISTRY[tool_id]


# ============ Tool Resource CRUD ============

@router.get("/resources")
def list_tool_resources(
    category: Optional[str] = None,
    enabled: Optional[bool] = None,
    include_system: bool = True,
    page: int = 1,
    limit: int = 100,
    db: Session = Depends(get_db),
):
    """List tool resources with optional filters and pagination.

    The ``system``/``query`` category only denotes the tool's execution type;
    it does not represent a permission level.

    ``page`` and ``limit`` are clamped to a minimum of 1 so that zero or
    negative values cannot produce a negative OFFSET/LIMIT; the response
    reports the effective values.
    """
    _seed_default_tools_if_empty(db)
    page = max(page, 1)
    limit = max(limit, 1)

    query = db.query(ToolResource)
    if not include_system:
        query = query.filter(ToolResource.is_system == False)  # noqa: E712 (SQL expression)
    if category:
        query = query.filter(ToolResource.category == category)
    if enabled is not None:
        query = query.filter(ToolResource.enabled == enabled)

    total = query.count()
    rows = (
        query.order_by(ToolResource.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": rows}


@router.get("/resources/{id}", response_model=ToolResourceOut)
def get_tool_resource(id: str, db: Session = Depends(get_db)):
    """Return a single tool resource, or 404 if it does not exist."""
    return _get_tool_resource_or_404(db, id)


@router.post("/resources", response_model=ToolResourceOut)
def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)):
    """Create a custom (non-system) tool resource.

    Uses the caller-supplied id when given (400 if it is already taken),
    otherwise generates ``tool_<8 hex chars>``.
    """
    _seed_default_tools_if_empty(db)
    candidate_id = (data.id or "").strip()
    if candidate_id and db.query(ToolResource).filter(ToolResource.id == candidate_id).first():
        raise HTTPException(status_code=400, detail="Tool ID already exists")

    item = ToolResource(
        id=candidate_id or f"tool_{str(uuid.uuid4())[:8]}",
        user_id=1,
        name=data.name,
        description=data.description,
        category=data.category,
        icon=data.icon,
        enabled=data.enabled,
        is_system=False,
    )
    db.add(item)
    db.commit()
    db.refresh(item)
    return item


@router.put("/resources/{id}", response_model=ToolResourceOut)
def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)):
    """Apply the explicitly provided fields of *data* to a tool resource."""
    item = _get_tool_resource_or_404(db, id)

    # exclude_unset: only fields the client actually sent are written.
    for field, value in data.model_dump(exclude_unset=True).items():
        setattr(item, field, value)
    item.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(item)
    return item


@router.delete("/resources/{id}")
def delete_tool_resource(id: str, db: Session = Depends(get_db)):
    """Delete a tool resource, or 404 if it does not exist."""
    item = _get_tool_resource_or_404(db, id)
    db.delete(item)
    db.commit()
    return {"message": "Deleted successfully"}


# ============ Autotest ============

class AutotestResult:
    """Accumulates individual autotest checks and a pass/fail summary."""

    def __init__(self):
        self.id = str(uuid.uuid4())[:8]
        # Wall-clock start time; also returned to the client as a timestamp.
        self.started_at = time.time()
        self.tests: List[Dict[str, Any]] = []
        self.summary: Dict[str, int] = {"passed": 0, "failed": 0, "total": 0}

    def add_test(self, name: str, passed: bool, message: str = "", duration_ms: int = 0):
        """Record one check and update the summary counters."""
        self.tests.append({
            "name": name,
            "passed": passed,
            "message": message,
            "duration_ms": duration_ms
        })
        self.summary["passed" if passed else "failed"] += 1
        self.summary["total"] += 1

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dict, with total elapsed time in ms."""
        return {
            "id": self.id,
            "started_at": self.started_at,
            "duration_ms": int((time.time() - self.started_at) * 1000),
            "tests": self.tests,
            "summary": self.summary
        }


@router.post("/autotest")
def run_autotest(
    llm_model_id: Optional[str] = None,
    asr_model_id: Optional[str] = None,
    test_llm: bool = True,
    test_asr: bool = True,
    db: Session = Depends(get_db)
):
    """Run the LLM and/or ASR autotests and return the combined result.

    A requested test whose model id is missing is recorded as a failed check
    rather than being skipped silently.
    """
    result = AutotestResult()

    # Run the tests for which a model id was supplied.
    if test_llm and llm_model_id:
        _test_llm_model(db, llm_model_id, result)
    if test_asr and asr_model_id:
        _test_asr_model(db, asr_model_id, result)

    # Report requested tests that could not run because no id was given.
    if test_llm and not llm_model_id:
        result.add_test("LLM Model Check", False, "No LLM model ID provided")
    if test_asr and not asr_model_id:
        result.add_test("ASR Model Check", False, "No ASR model ID provided")

    return result.to_dict()


@router.post("/autotest/llm/{model_id}")
def autotest_llm_model(model_id: str, db: Session = Depends(get_db)):
    """Run the autotest for a single LLM model."""
    result = AutotestResult()
    _test_llm_model(db, model_id, result)
    return result.to_dict()


@router.post("/autotest/asr/{model_id}")
def autotest_asr_model(model_id: str, db: Session = Depends(get_db)):
    """Run the autotest for a single ASR model."""
    result = AutotestResult()
    _test_asr_model(db, model_id, result)
    return result.to_dict()


def _test_llm_model(db: Session, model_id: str, result: AutotestResult):
    """Run the LLM checks for *model_id*, appending outcomes to *result*.

    Checks, in order: model existence, a non-streaming chat completion,
    the temperature setting, and (for ``type == "text"`` models) a streaming
    chat completion. Request failures are recorded as failed checks, never
    raised.
    """
    # 1. Check that the model exists.
    start_time = time.perf_counter()
    model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
    duration_ms = _elapsed_ms(start_time)

    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return
    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # Shared request setup, built once so both API checks below can rely on
    # it regardless of how the other check fared.
    model_name = model.model_name or _DEFAULT_LLM_MODEL_NAME
    headers = {"Authorization": f"Bearer {model.api_key}"}
    url = _join_url(model.base_url, "chat/completions")

    # 2. Test the connection with a minimal completion request.
    test_start = time.perf_counter()
    try:
        payload = {
            "model": model_name,
            "messages": [{"role": "user", "content": "Reply with 'OK'."}],
            "max_tokens": 10,
            "temperature": 0.1,
        }
        with httpx.Client(timeout=30.0) as client:
            response = client.post(url, json=payload, headers=headers)
            response.raise_for_status()
            body = response.json()

        latency_ms = _elapsed_ms(test_start)
        if body.get("choices"):
            result.add_test("API Connection", True, f"Latency: {latency_ms}ms", latency_ms)
        else:
            result.add_test("API Connection", False, "Empty response", latency_ms)
    except Exception as e:
        # Any failure (network, HTTP status, bad JSON) is a failed check.
        result.add_test("API Connection", False, str(e)[:200], _elapsed_ms(test_start))

    # 3. Check the model configuration (informational; always passes).
    if model.temperature is not None:
        result.add_test("Temperature Setting", True, f"temperature={model.temperature}")
    else:
        result.add_test("Temperature Setting", True, "Using default")

    # 4. Test streaming responses (text models only).
    if model.type == "text":
        test_start = time.perf_counter()
        try:
            stream_payload = {
                "model": model_name,
                "messages": [{"role": "user", "content": "Count from 1 to 3."}],
                "stream": True,
            }
            with httpx.Client(timeout=30.0) as client:
                with client.stream("POST", url, json=stream_payload, headers=headers) as response:
                    response.raise_for_status()
                    chunk_count = sum(1 for _ in response.iter_bytes())

            latency_ms = _elapsed_ms(test_start)
            result.add_test("Streaming Support", True, f"Received {chunk_count} chunks", latency_ms)
        except Exception as e:
            result.add_test("Streaming Support", False, str(e)[:200], _elapsed_ms(test_start))


def _test_asr_model(db: Session, model_id: str, result: AutotestResult):
    """Run the ASR checks for *model_id*, appending outcomes to *result*.

    Checks, in order: model existence, hotwords configuration, API
    availability (vendor-specific probe endpoint) and language configuration.
    Request failures are recorded as failed checks, never raised.
    """
    # 1. Check that the model exists.
    start_time = time.perf_counter()
    model = db.query(ASRModel).filter(ASRModel.id == model_id).first()
    duration_ms = _elapsed_ms(start_time)

    if not model:
        result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
        return
    result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)

    # 2. Check the hotwords configuration (informational; always passes).
    if model.hotwords:
        result.add_test("Hotwords Config", True, f"Hotwords: {len(model.hotwords)} words")
    else:
        result.add_test("Hotwords Config", True, "No hotwords configured")

    # 3. Probe API availability.
    test_start = time.perf_counter()
    try:
        headers = {"Authorization": f"Bearer {model.api_key}"}
        # A missing vendor falls through to the generic health check instead
        # of raising on None.lower().
        vendor = (model.vendor or "").lower()
        if vendor in ("siliconflow", "paraformer"):
            probe_path = "asr"
        elif vendor == "openai":
            probe_path = "audio/models"
        else:
            # Generic health check.
            probe_path = "health"

        with httpx.Client(timeout=30.0) as client:
            response = client.get(_join_url(model.base_url, probe_path), headers=headers)

        latency_ms = _elapsed_ms(test_start)
        # 405 = method not allowed, but the endpoint exists.
        available = response.status_code in (200, 405)
        result.add_test("API Availability", available, f"Status: {response.status_code}", latency_ms)
    except httpx.TimeoutException:
        result.add_test("API Availability", False, "Connection timeout", _elapsed_ms(test_start))
    except Exception as e:
        result.add_test("API Availability", False, str(e)[:200], _elapsed_ms(test_start))

    # 4. Check the language configuration.
    if model.language in _KNOWN_ASR_LANGUAGES:
        result.add_test("Language Config", True, f"Language: {model.language}")
    else:
        result.add_test("Language Config", False, f"Unknown language: {model.language}")


# ============ Quick Health Check ============

@router.get("/health")
def health_check():
    """Return a static health payload with the registered tool ids."""
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "tools": list(TOOL_REGISTRY.keys())
    }


@router.post("/test-message")
def send_test_message(
    llm_model_id: str,
    message: str = "Hello, this is a test message.",
    db: Session = Depends(get_db)
):
    """Send *message* to the given LLM model and return its reply.

    Raises:
        HTTPException: 404 if the model does not exist; 500 if the upstream
            request or response parsing fails.
    """
    model = db.query(LLMModel).filter(LLMModel.id == llm_model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="LLM Model not found")

    try:
        payload = {
            "model": model.model_name or _DEFAULT_LLM_MODEL_NAME,
            "messages": [{"role": "user", "content": message}],
            "max_tokens": 500,
            "temperature": 0.7,
        }
        headers = {"Authorization": f"Bearer {model.api_key}"}

        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                _join_url(model.base_url, "chat/completions"),
                json=payload,
                headers=headers,
            )
            response.raise_for_status()
            body = response.json()

        # Tolerate an empty "choices" list or a null "message" instead of
        # failing with IndexError/AttributeError.
        choices = body.get("choices") or [{}]
        reply = (choices[0].get("message") or {}).get("content", "") or ""

        return {
            "success": True,
            "reply": reply,
            "usage": body.get("usage", {})
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e