From 7012f8edaf8f4ba18265695dde60e23e6e1a4b04 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Sun, 8 Feb 2026 15:52:16 +0800 Subject: [PATCH] Update backend api --- api/app/main.py | 6 +- api/app/routers/__init__.py | 6 + api/app/routers/asr.py | 268 ++++++++++++++++++++ api/app/routers/assistants.py | 119 ++++++++- api/app/routers/llm.py | 206 ++++++++++++++++ api/app/routers/tools.py | 379 +++++++++++++++++++++++++++++ api/docs/asr.md | 409 +++++++++++++++++++++++++++++++ api/docs/index.md | 6 +- api/docs/llm.md | 401 ++++++++++++++++++++++++++++++ api/docs/tools.md | 445 ++++++++++++++++++++++++++++++++++ api/init_db.py | 373 +++++++++++++++++++++++++++- api/tests/conftest.py | 35 +++ api/tests/test_asr.py | 289 ++++++++++++++++++++++ api/tests/test_llm.py | 246 +++++++++++++++++++ api/tests/test_tools.py | 267 ++++++++++++++++++++ 15 files changed, 3436 insertions(+), 19 deletions(-) create mode 100644 api/app/routers/asr.py create mode 100644 api/app/routers/llm.py create mode 100644 api/app/routers/tools.py create mode 100644 api/docs/asr.md create mode 100644 api/docs/llm.md create mode 100644 api/docs/tools.md create mode 100644 api/tests/test_asr.py create mode 100644 api/tests/test_llm.py create mode 100644 api/tests/test_tools.py diff --git a/api/app/main.py b/api/app/main.py index 14d9b19..1573b67 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -4,7 +4,7 @@ from contextlib import asynccontextmanager import os from .db import Base, engine -from .routers import assistants, history +from .routers import assistants, history, knowledge, llm, asr, tools @asynccontextmanager @@ -33,6 +33,10 @@ app.add_middleware( # 路由 app.include_router(assistants.router, prefix="/api") app.include_router(history.router, prefix="/api") +app.include_router(knowledge.router, prefix="/api") +app.include_router(llm.router, prefix="/api") +app.include_router(asr.router, prefix="/api") +app.include_router(tools.router, prefix="/api") @app.get("/") diff --git a/api/app/routers/__init__.py b/api/app/routers/__init__.py index 5b92416..2d68474 100644 --- a/api/app/routers/__init__.py +++ b/api/app/routers/__init__.py @@ -3,9 +3,15 @@ from fastapi import APIRouter from . import assistants from . import history from . import knowledge +from . import llm +from . import asr +from . import tools router = APIRouter() router.include_router(assistants.router) router.include_router(history.router) router.include_router(knowledge.router) +router.include_router(llm.router) +router.include_router(asr.router) +router.include_router(tools.router) diff --git a/api/app/routers/asr.py b/api/app/routers/asr.py new file mode 100644 index 0000000..8dd5822 --- /dev/null +++ b/api/app/routers/asr.py @@ -0,0 +1,268 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session +from typing import List, Optional +import uuid +import httpx +import time +import base64 +import json +from datetime import datetime + +from ..db import get_db +from ..models import ASRModel +from ..schemas import ( + ASRModelCreate, ASRModelUpdate, ASRModelOut, + ASRTestRequest, ASRTestResponse, ListResponse +) + +router = APIRouter(prefix="/asr", tags=["ASR Models"]) + + +# ============ ASR Models CRUD ============ +@router.get("", response_model=ListResponse) +def list_asr_models( + language: Optional[str] = None, + enabled: Optional[bool] = None, + page: int = 1, + limit: int = 50, + db: Session = Depends(get_db) +): + """获取ASR模型列表""" + query = db.query(ASRModel) + + if language: + query = query.filter(ASRModel.language == language) + if enabled is not None: + query = query.filter(ASRModel.enabled == enabled) + + total = query.count() + models = query.order_by(ASRModel.created_at.desc()) \ + .offset((page-1)*limit).limit(limit).all() + + return {"total": total, "page": page, "limit": limit, "list": models} + + +@router.get("/{id}", response_model=ASRModelOut) +def get_asr_model(id: str, db: Session = Depends(get_db)): + """获取单个ASR模型详情""" + model = db.query(ASRModel).filter(ASRModel.id == id).first() + if not model: + raise HTTPException(status_code=404, detail="ASR Model not found") + return model + + +@router.post("", response_model=ASRModelOut) +def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)): + """创建ASR模型""" + asr_model = ASRModel( + id=data.id or str(uuid.uuid4())[:8], + user_id=1, # 默认用户 + name=data.name, + vendor=data.vendor, + language=data.language, + base_url=data.base_url, + api_key=data.api_key, + model_name=data.model_name, + hotwords=data.hotwords, + enable_punctuation=data.enable_punctuation, + enable_normalization=data.enable_normalization, + enabled=data.enabled, + ) + db.add(asr_model) + db.commit() + db.refresh(asr_model) + return asr_model + + +@router.put("/{id}", response_model=ASRModelOut) +def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)): + """更新ASR模型""" + model = db.query(ASRModel).filter(ASRModel.id == id).first() + if not model: + raise HTTPException(status_code=404, detail="ASR Model not found") + + update_data = data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(model, field, value) + + db.commit() + db.refresh(model) + return model + + +@router.delete("/{id}") +def delete_asr_model(id: str, db: Session = Depends(get_db)): + """删除ASR模型""" + model = db.query(ASRModel).filter(ASRModel.id == id).first() + if not model: + raise HTTPException(status_code=404, detail="ASR Model not found") + db.delete(model) + db.commit() + return {"message": "Deleted successfully"} + + +@router.post("/{id}/test", response_model=ASRTestResponse) +def test_asr_model( + id: str, + request: Optional[ASRTestRequest] = None, + db: Session = Depends(get_db) +): + """测试ASR模型""" + model = db.query(ASRModel).filter(ASRModel.id == id).first() + if not model: + raise HTTPException(status_code=404, detail="ASR Model not found") + + start_time = time.time() + + try: + # 根据不同的厂商构造不同的请求 + if model.vendor.lower() in ["siliconflow", "paraformer"]: + # SiliconFlow/Paraformer 格式 + payload = { + "model": model.model_name or "paraformer-v2", + "input": {}, + "parameters": { + "hotwords": " ".join(model.hotwords) if model.hotwords else "", + "enable_punctuation": model.enable_punctuation, + "enable_normalization": model.enable_normalization, + } + } + + # 如果有音频数据 + if request and request.audio_data: + payload["input"]["file_urls"] = [] + elif request and request.audio_url: + payload["input"]["url"] = request.audio_url + + headers = {"Authorization": f"Bearer {model.api_key}"} + + with httpx.Client(timeout=60.0) as client: + response = client.post( + f"{model.base_url}/asr", + json=payload, + headers=headers + ) + response.raise_for_status() + result = response.json() + + elif model.vendor.lower() == "openai": + # OpenAI Whisper 格式 + headers = {"Authorization": f"Bearer {model.api_key}"} + + # 准备文件 + files = {} + if request and request.audio_data: + audio_bytes = base64.b64decode(request.audio_data) + files = {"file": ("audio.wav", audio_bytes, "audio/wav")} + data = {"model": model.model_name or "whisper-1"} + elif request and request.audio_url: + files = {"file": ("audio.wav", httpx.get(request.audio_url).content, "audio/wav")} + data = {"model": model.model_name or "whisper-1"} + else: + return ASRTestResponse( + success=False, + error="No audio data or URL provided" + ) + + with httpx.Client(timeout=60.0) as client: + response = client.post( + f"{model.base_url}/audio/transcriptions", + files=files, + data=data, + headers=headers + ) + response.raise_for_status() + result = response.json() + result = {"results": [{"transcript": result.get("text", "")}]} + + else: + # 通用格式(可根据需要扩展) + return ASRTestResponse( + success=False, + message=f"Unsupported vendor: {model.vendor}" + ) + + latency_ms = int((time.time() - start_time) * 1000) + + # 解析结果 + if result_data := result.get("results", [{}])[0]: + transcript = result_data.get("transcript", "") + return ASRTestResponse( + success=True, + transcript=transcript, + language=result_data.get("language", model.language), + confidence=result_data.get("confidence"), + latency_ms=latency_ms, + ) + + return ASRTestResponse( + success=False, + message="No transcript in response", + latency_ms=latency_ms + ) + + except httpx.HTTPStatusError as e: + return ASRTestResponse( + success=False, + error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}" + ) + except Exception as e: + return ASRTestResponse( + success=False, + error=str(e)[:200] + ) + + +@router.post("/{id}/transcribe") +def transcribe_audio( + id: str, + audio_url: Optional[str] = None, + audio_data: Optional[str] = None, + hotwords: Optional[List[str]] = None, + db: Session = Depends(get_db) +): + """转写音频""" + model = db.query(ASRModel).filter(ASRModel.id == id).first() + if not model: + raise HTTPException(status_code=404, detail="ASR Model not found") + + try: + payload = { + "model": model.model_name or "paraformer-v2", + "input": {}, + "parameters": { + "hotwords": " ".join(hotwords or model.hotwords or []), + "enable_punctuation": model.enable_punctuation, + "enable_normalization": model.enable_normalization, + } + } + + headers = {"Authorization": f"Bearer {model.api_key}"} + + if audio_url: + payload["input"]["url"] = audio_url + elif audio_data: + payload["input"]["file_urls"] = [] + + with httpx.Client(timeout=120.0) as client: + response = client.post( + f"{model.base_url}/asr", + json=payload, + headers=headers + ) + response.raise_for_status() + + result = response.json() + + if result_data := result.get("results", [{}])[0]: + return { + "success": True, + "transcript": result_data.get("transcript", ""), + "language": result_data.get("language", model.language), + "confidence": result_data.get("confidence"), + } + + return {"success": False, "error": "No transcript in response"} + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/api/app/routers/assistants.py b/api/app/routers/assistants.py index b67fc62..a756060 100644 --- a/api/app/routers/assistants.py +++ b/api/app/routers/assistants.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session -from typing import List +from typing import List, Optional import uuid from datetime import datetime @@ -8,7 +8,7 @@ from ..db import get_db from ..models import Assistant, Voice, Workflow from ..schemas import ( AssistantCreate, AssistantUpdate, AssistantOut, - VoiceOut, + VoiceCreate, VoiceUpdate, VoiceOut, WorkflowCreate, WorkflowUpdate, WorkflowOut ) @@ -16,11 +16,88 @@ router = APIRouter() # ============ Voices ============ -@router.get("/voices", response_model=List[VoiceOut]) -def list_voices(db: Session = Depends(get_db)): +@router.get("/voices") +def list_voices( + vendor: Optional[str] = None, + language: Optional[str] = None, + gender: Optional[str] = None, + page: int = 1, + limit: int = 50, + db: Session = Depends(get_db) +): """获取声音库列表""" - voices = db.query(Voice).all() - return voices + query = db.query(Voice) + if vendor: + query = query.filter(Voice.vendor == vendor) + if language: + query = query.filter(Voice.language == language) + if gender: + query = query.filter(Voice.gender == gender) + + total = query.count() + voices = query.order_by(Voice.created_at.desc()) \ + .offset((page-1)*limit).limit(limit).all() + return {"total": total, "page": page, "limit": limit, "list": voices} + + +@router.post("/voices", response_model=VoiceOut) +def create_voice(data: VoiceCreate, db: Session = Depends(get_db)): + """创建声音""" + voice = Voice( + id=data.id or str(uuid.uuid4())[:8], + user_id=1, + name=data.name, + vendor=data.vendor, + gender=data.gender, + language=data.language, + description=data.description, + model=data.model, + voice_key=data.voice_key, + speed=data.speed, + gain=data.gain, + pitch=data.pitch, + enabled=data.enabled, + ) + db.add(voice) + db.commit() + db.refresh(voice) + return voice + + +@router.get("/voices/{id}", response_model=VoiceOut) +def get_voice(id: str, db: Session = Depends(get_db)): + """获取单个声音详情""" + voice = db.query(Voice).filter(Voice.id == id).first() + if not voice: + raise HTTPException(status_code=404, detail="Voice not found") + return voice + + +@router.put("/voices/{id}", response_model=VoiceOut) +def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)): + """更新声音""" + voice = db.query(Voice).filter(Voice.id == id).first() + if not voice: + raise HTTPException(status_code=404, detail="Voice not found") + + update_data = data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(voice, field, value) + + db.commit() + db.refresh(voice) + return voice + + +@router.delete("/voices/{id}") +def delete_voice(id: str, db: Session = Depends(get_db)): + """删除声音""" + voice = db.query(Voice).filter(Voice.id == id).first() + if not voice: + raise HTTPException(status_code=404, detail="Voice not found") + db.delete(voice) + db.commit() + return {"message": "Deleted successfully"} # ============ Assistants ============ @@ -79,11 +156,11 @@ def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_d assistant = db.query(Assistant).filter(Assistant.id == id).first() if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") - + update_data = data.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(assistant, field, value) - + assistant.updated_at = datetime.utcnow() db.commit() db.refresh(assistant) @@ -103,10 +180,17 @@ def delete_assistant(id: str, db: Session = Depends(get_db)): # ============ Workflows ============ @router.get("/workflows", response_model=List[WorkflowOut]) -def list_workflows(db: Session = Depends(get_db)): +def list_workflows( + page: int = 1, + limit: int = 50, + db: Session = Depends(get_db) +): """获取工作流列表""" - workflows = db.query(Workflow).all() - return workflows + query = db.query(Workflow) + total = query.count() + workflows = query.order_by(Workflow.created_at.desc()) \ + .offset((page-1)*limit).limit(limit).all() + return {"total": total, "page": page, "limit": limit, "list": workflows} @router.post("/workflows", response_model=WorkflowOut) @@ -129,17 +213,26 @@ def create_workflow(data: WorkflowCreate, db: Session = Depends(get_db)): return workflow +@router.get("/workflows/{id}", response_model=WorkflowOut) +def get_workflow(id: str, db: Session = Depends(get_db)): + """获取单个工作流""" + workflow = db.query(Workflow).filter(Workflow.id == id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + return workflow + + @router.put("/workflows/{id}", response_model=WorkflowOut) def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)): """更新工作流""" workflow = db.query(Workflow).filter(Workflow.id == id).first() if not workflow: raise HTTPException(status_code=404, detail="Workflow not found") - + update_data = data.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(workflow, field, value) - + workflow.updated_at = datetime.utcnow().isoformat() db.commit() db.refresh(workflow) diff --git a/api/app/routers/llm.py b/api/app/routers/llm.py new file mode 100644 index 0000000..71c854b --- /dev/null +++ b/api/app/routers/llm.py @@ -0,0 +1,206 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session +from typing import List, Optional +import uuid +import httpx +import time +from datetime import datetime + +from ..db import get_db +from ..models import LLMModel +from ..schemas import ( + LLMModelCreate, LLMModelUpdate, LLMModelOut, + LLMModelTestResponse, ListResponse +) + +router = APIRouter(prefix="/llm", tags=["LLM Models"]) + + +# ============ LLM Models CRUD ============ +@router.get("", response_model=ListResponse) +def list_llm_models( + model_type: Optional[str] = None, + enabled: Optional[bool] = None, + page: int = 1, + limit: int = 50, + db: Session = Depends(get_db) +): + """获取LLM模型列表""" + query = db.query(LLMModel) + + if model_type: + query = query.filter(LLMModel.type == model_type) + if enabled is not None: + query = query.filter(LLMModel.enabled == enabled) + + total = query.count() + models = query.order_by(LLMModel.created_at.desc()) \ + .offset((page-1)*limit).limit(limit).all() + + return {"total": total, "page": page, "limit": limit, "list": models} + + +@router.get("/{id}", response_model=LLMModelOut) +def get_llm_model(id: str, db: Session = Depends(get_db)): + """获取单个LLM模型详情""" + model = db.query(LLMModel).filter(LLMModel.id == id).first() + if not model: + raise HTTPException(status_code=404, detail="LLM Model not found") + return model + + +@router.post("", response_model=LLMModelOut) +def create_llm_model(data: LLMModelCreate, db: Session = Depends(get_db)): + """创建LLM模型""" + llm_model = LLMModel( + id=data.id or str(uuid.uuid4())[:8], + user_id=1, # 默认用户 + name=data.name, + vendor=data.vendor, + type=data.type.value if hasattr(data.type, 'value') else data.type, + base_url=data.base_url, + api_key=data.api_key, + model_name=data.model_name, + temperature=data.temperature, + context_length=data.context_length, + enabled=data.enabled, + ) + db.add(llm_model) + db.commit() + db.refresh(llm_model) + return llm_model + + +@router.put("/{id}", response_model=LLMModelOut) +def update_llm_model(id: str, data: LLMModelUpdate, db: Session = Depends(get_db)): + """更新LLM模型""" + model = db.query(LLMModel).filter(LLMModel.id == id).first() + if not model: + raise HTTPException(status_code=404, detail="LLM Model not found") + + update_data = data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(model, field, value) + + model.updated_at = datetime.utcnow() + db.commit() + db.refresh(model) + return model + + +@router.delete("/{id}") +def delete_llm_model(id: str, db: Session = Depends(get_db)): + """删除LLM模型""" + model = db.query(LLMModel).filter(LLMModel.id == id).first() + if not model: + raise HTTPException(status_code=404, detail="LLM Model not found") + db.delete(model) + db.commit() + return {"message": "Deleted successfully"} + + +@router.post("/{id}/test", response_model=LLMModelTestResponse) +def test_llm_model(id: str, db: Session = Depends(get_db)): + """测试LLM模型连接""" + model = db.query(LLMModel).filter(LLMModel.id == id).first() + if not model: + raise HTTPException(status_code=404, detail="LLM Model not found") + + start_time = time.time() + try: + # 构造测试请求 + test_messages = [{"role": "user", "content": "Hello, please reply with 'OK'."}] + + payload = { + "model": model.model_name or "gpt-3.5-turbo", + "messages": test_messages, + "max_tokens": 10, + "temperature": 0.1, + } + + headers = {"Authorization": f"Bearer {model.api_key}"} + + with httpx.Client(timeout=30.0) as client: + response = client.post( + f"{model.base_url}/chat/completions", + json=payload, + headers=headers + ) + response.raise_for_status() + + latency_ms = int((time.time() - start_time) * 1000) + result = response.json() + + if result.get("choices"): + return LLMModelTestResponse( + success=True, + latency_ms=latency_ms, + message="Connection successful" + ) + else: + return LLMModelTestResponse( + success=False, + latency_ms=latency_ms, + message="Unexpected response format" + ) + + except httpx.HTTPStatusError as e: + return LLMModelTestResponse( + success=False, + message=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}" + ) + except Exception as e: + return LLMModelTestResponse( + success=False, + message=str(e)[:200] + ) + + +@router.post("/{id}/chat") +def chat_with_llm( + id: str, + message: str, + system_prompt: Optional[str] = None, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + db: Session = Depends(get_db) +): + """与LLM模型对话""" + model = db.query(LLMModel).filter(LLMModel.id == id).first() + if not model: + raise HTTPException(status_code=404, detail="LLM Model not found") + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": message}) + + payload = { + "model": model.model_name or "gpt-3.5-turbo", + "messages": messages, + "max_tokens": max_tokens or 1000, + "temperature": temperature if temperature is not None else model.temperature or 0.7, + } + + headers = {"Authorization": f"Bearer {model.api_key}"} + + try: + with httpx.Client(timeout=60.0) as client: + response = client.post( + f"{model.base_url}/chat/completions", + json=payload, + headers=headers + ) + response.raise_for_status() + + result = response.json() + if choice := result.get("choices", [{}])[0]: + return { + "success": True, + "reply": choice.get("message", {}).get("content", ""), + "usage": result.get("usage", {}) + } + return {"success": False, "reply": "", "error": "No response"} + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/api/app/routers/tools.py b/api/app/routers/tools.py new file mode 100644 index 0000000..1c34622 --- /dev/null +++ b/api/app/routers/tools.py @@ -0,0 +1,379 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session +from typing import Optional, Dict, Any +import time +import uuid +import httpx + +from ..db import get_db +from ..models import LLMModel, ASRModel + +router = APIRouter(prefix="/tools", tags=["Tools & Autotest"]) + + +# ============ Available Tools ============ +TOOL_REGISTRY = { + "search": { + "name": "网络搜索", + "description": "搜索互联网获取最新信息", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "搜索关键词"} + }, + "required": ["query"] + } + }, + "calculator": { + "name": "计算器", + "description": "执行数学计算", + "parameters": { + "type": "object", + "properties": { + "expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"} + }, + "required": ["expression"] + } + }, + "weather": { + "name": "天气查询", + "description": "查询指定城市的天气", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "城市名称"} + }, + "required": ["city"] + } + }, + "translate": { + "name": "翻译", + "description": "翻译文本到指定语言", + "parameters": { + "type": "object", + "properties": { + "text": {"type": "string", "description": "要翻译的文本"}, + "target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"} + }, + "required": ["text", "target_lang"] + } + }, + "knowledge": { + "name": "知识库查询", + "description": "从知识库中检索相关信息", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "查询内容"}, + "kb_id": {"type": "string", "description": "知识库ID"} + }, + "required": ["query"] + } + }, + "code_interpreter": { + "name": "代码执行", + "description": "安全地执行Python代码", + "parameters": { + "type": "object", + "properties": { + "code": {"type": "string", "description": "要执行的Python代码"} + }, + "required": ["code"] + } + }, +} + + +@router.get("/list") +def list_available_tools(): + """获取可用的工具列表""" + return {"tools": TOOL_REGISTRY} + + +@router.get("/list/{tool_id}") +def get_tool_detail(tool_id: str): + """获取工具详情""" + if tool_id not in TOOL_REGISTRY: + raise HTTPException(status_code=404, detail="Tool not found") + return TOOL_REGISTRY[tool_id] + + +# ============ Autotest ============ +class AutotestResult: + """自动测试结果""" + + def __init__(self): + self.id = str(uuid.uuid4())[:8] + self.started_at = time.time() + self.tests = [] + self.summary = {"passed": 0, "failed": 0, "total": 0} + + def add_test(self, name: str, passed: bool, message: str = "", duration_ms: int = 0): + self.tests.append({ + "name": name, + "passed": passed, + "message": message, + "duration_ms": duration_ms + }) + if passed: + self.summary["passed"] += 1 + else: + self.summary["failed"] += 1 + self.summary["total"] += 1 + + def to_dict(self): + return { + "id": self.id, + "started_at": self.started_at, + "duration_ms": int((time.time() - self.started_at) * 1000), + "tests": self.tests, + "summary": self.summary + } + + +@router.post("/autotest") +def run_autotest( + llm_model_id: Optional[str] = None, + asr_model_id: Optional[str] = None, + test_llm: bool = True, + test_asr: bool = True, + db: Session = Depends(get_db) +): + """运行自动测试""" + result = AutotestResult() + + # 测试 LLM 模型 + if test_llm and llm_model_id: + _test_llm_model(db, llm_model_id, result) + + # 测试 ASR 模型 + if test_asr and asr_model_id: + _test_asr_model(db, asr_model_id, result) + + # 测试 TTS 功能(需要时可添加) + if test_llm and not llm_model_id: + result.add_test( + "LLM Model Check", + False, + "No LLM model ID provided" + ) + + if test_asr and not asr_model_id: + result.add_test( + "ASR Model Check", + False, + "No ASR model ID provided" + ) + + return result.to_dict() + + +@router.post("/autotest/llm/{model_id}") +def autotest_llm_model(model_id: str, db: Session = Depends(get_db)): + """测试单个LLM模型""" + result = AutotestResult() + _test_llm_model(db, model_id, result) + return result.to_dict() + + +@router.post("/autotest/asr/{model_id}") +def autotest_asr_model(model_id: str, db: Session = Depends(get_db)): + """测试单个ASR模型""" + result = AutotestResult() + _test_asr_model(db, model_id, result) + return result.to_dict() + + +def _test_llm_model(db: Session, model_id: str, result: AutotestResult): + """内部方法:测试LLM模型""" + start_time = time.time() + + # 1. 检查模型是否存在 + model = db.query(LLMModel).filter(LLMModel.id == model_id).first() + duration_ms = int((time.time() - start_time) * 1000) + + if not model: + result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms) + return + + result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms) + + # 2. 测试连接 + test_start = time.time() + try: + test_messages = [{"role": "user", "content": "Reply with 'OK'."}] + payload = { + "model": model.model_name or "gpt-3.5-turbo", + "messages": test_messages, + "max_tokens": 10, + "temperature": 0.1, + } + headers = {"Authorization": f"Bearer {model.api_key}"} + + with httpx.Client(timeout=30.0) as client: + response = client.post( + f"{model.base_url}/chat/completions", + json=payload, + headers=headers + ) + response.raise_for_status() + + result_text = response.json() + latency_ms = int((time.time() - test_start) * 1000) + + if result_text.get("choices"): + result.add_test("API Connection", True, f"Latency: {latency_ms}ms", latency_ms) + else: + result.add_test("API Connection", False, "Empty response", latency_ms) + + except Exception as e: + latency_ms = int((time.time() - test_start) * 1000) + result.add_test("API Connection", False, str(e)[:200], latency_ms) + + # 3. 检查模型配置 + if model.temperature is not None: + result.add_test("Temperature Setting", True, f"temperature={model.temperature}") + else: + result.add_test("Temperature Setting", True, "Using default") + + # 4. 测试流式响应(可选) + if model.type == "text": + test_start = time.time() + try: + with httpx.Client(timeout=30.0) as client: + with client.stream( + "POST", + f"{model.base_url}/chat/completions", + json={ + "model": model.model_name or "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Count from 1 to 3."}], + "stream": True, + }, + headers=headers + ) as response: + response.raise_for_status() + chunk_count = 0 + for _ in response.iter_bytes(): + chunk_count += 1 + + latency_ms = int((time.time() - test_start) * 1000) + result.add_test("Streaming Support", True, f"Received {chunk_count} chunks", latency_ms) + except Exception as e: + latency_ms = int((time.time() - test_start) * 1000) + result.add_test("Streaming Support", False, str(e)[:200], latency_ms) + + +def _test_asr_model(db: Session, model_id: str, result: AutotestResult): + """内部方法:测试ASR模型""" + start_time = time.time() + + # 1. 检查模型是否存在 + model = db.query(ASRModel).filter(ASRModel.id == model_id).first() + duration_ms = int((time.time() - start_time) * 1000) + + if not model: + result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms) + return + + result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms) + + # 2. 测试配置 + if model.hotwords: + result.add_test("Hotwords Config", True, f"Hotwords: {len(model.hotwords)} words") + else: + result.add_test("Hotwords Config", True, "No hotwords configured") + + # 3. 测试API可用性 + test_start = time.time() + try: + headers = {"Authorization": f"Bearer {model.api_key}"} + + with httpx.Client(timeout=30.0) as client: + if model.vendor.lower() in ["siliconflow", "paraformer"]: + response = client.get( + f"{model.base_url}/asr", + headers=headers + ) + elif model.vendor.lower() == "openai": + response = client.get( + f"{model.base_url}/audio/models", + headers=headers + ) + else: + # 通用健康检查 + response = client.get( + f"{model.base_url}/health", + headers=headers + ) + + latency_ms = int((time.time() - test_start) * 1000) + + if response.status_code in [200, 405]: # 405 = method not allowed but endpoint exists + result.add_test("API Availability", True, f"Status: {response.status_code}", latency_ms) + else: + result.add_test("API Availability", False, f"Status: {response.status_code}", latency_ms) + + except httpx.TimeoutException: + latency_ms = int((time.time() - test_start) * 1000) + result.add_test("API Availability", False, "Connection timeout", latency_ms) + except Exception as e: + latency_ms = int((time.time() - test_start) * 1000) + result.add_test("API Availability", False, str(e)[:200], latency_ms) + + # 4. 检查语言配置 + if model.language in ["zh", "en", "Multi-lingual"]: + result.add_test("Language Config", True, f"Language: {model.language}") + else: + result.add_test("Language Config", False, f"Unknown language: {model.language}") + + +# ============ Quick Health Check ============ +@router.get("/health") +def health_check(): + """快速健康检查""" + return { + "status": "healthy", + "timestamp": time.time(), + "tools": list(TOOL_REGISTRY.keys()) + } + + +@router.post("/test-message") +def send_test_message( + llm_model_id: str, + message: str = "Hello, this is a test message.", + db: Session = Depends(get_db) +): + """发送测试消息""" + model = db.query(LLMModel).filter(LLMModel.id == llm_model_id).first() + if not model: + raise HTTPException(status_code=404, detail="LLM Model not found") + + try: + payload = { + "model": model.model_name or "gpt-3.5-turbo", + "messages": [{"role": "user", "content": message}], + "max_tokens": 500, + "temperature": 0.7, + } + headers = {"Authorization": f"Bearer {model.api_key}"} + + with httpx.Client(timeout=60.0) as client: + response = client.post( + f"{model.base_url}/chat/completions", + json=payload, + headers=headers + ) + response.raise_for_status() + + result = response.json() + reply = result.get("choices", [{}])[0].get("message", {}).get("content", "") + + return { + "success": True, + "reply": reply, + "usage": result.get("usage", {}) + } + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/api/docs/asr.md b/api/docs/asr.md new file mode 100644 index 0000000..08767ad --- /dev/null +++ b/api/docs/asr.md @@ -0,0 +1,409 @@ +# 语音识别 (ASR Model) API + +语音识别 API 用于管理语音识别模型的配置和调用。 + +## 基础信息 + +| 项目 | 值 | +|------|-----| +| Base URL | `/api/v1/asr` | +| 认证方式 | Bearer Token (预留) | + +--- + +## 数据模型 + +### ASRModel + +```typescript +interface ASRModel { + id: string; // 模型唯一标识 (8位UUID) + user_id: number; // 所属用户ID + name: string; // 模型显示名称 + vendor: string; // 供应商: "OpenAI" | "SiliconFlow" | "Paraformer" | 等 + language: string; // 识别语言: "zh" | "en" | "Multi-lingual" + base_url: string; // API Base URL + api_key: string; // API Key + model_name?: string; // 模型名称,如 "whisper-1" | "paraformer-v2" + hotwords?: string[]; // 热词列表 + enable_punctuation: boolean; // 是否启用标点 + enable_normalization: boolean; // 是否启用文本规范化 + enabled: boolean; // 是否启用 + created_at: string; +} +``` + +--- + +## API 端点 + +### 1. 获取 ASR 模型列表 + +```http +GET /api/v1/asr +``` + +**Query Parameters:** + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|------|------|------|--------|------| +| language | string | 否 | - | 过滤语言: "zh" \| "en" \| "Multi-lingual" | +| enabled | boolean | 否 | - | 过滤启用状态 | +| page | int | 否 | 1 | 页码 | +| limit | int | 否 | 50 | 每页数量 | + +**Response:** + +```json +{ + "total": 3, + "page": 1, + "limit": 50, + "list": [ + { + "id": "abc12345", + "user_id": 1, + "name": "Whisper 多语种识别", + "vendor": "OpenAI", + "language": "Multi-lingual", + "base_url": "https://api.openai.com/v1", + "api_key": "sk-***", + "model_name": "whisper-1", + "enable_punctuation": true, + "enable_normalization": true, + "enabled": true, + "created_at": "2024-01-15T10:30:00Z" + }, + { + "id": "def67890", + "user_id": 1, + "name": "SenseVoice 中文识别", + "vendor": "SiliconFlow", + "language": "zh", + "base_url": "https://api.siliconflow.cn/v1", + "api_key": "sf-***", + "model_name": "paraformer-v2", + "hotwords": ["小助手", "帮我"], + "enable_punctuation": true, + "enable_normalization": true, + "enabled": true, + "created_at": "2024-01-15T10:30:00Z" + } + ] +} +``` + +--- + +### 2. 获取单个 ASR 模型详情 + +```http +GET /api/v1/asr/{id} +``` + +**Path Parameters:** + +| 参数 | 类型 | 说明 | +|------|------|------| +| id | string | 模型ID | + +**Response:** + +```json +{ + "id": "abc12345", + "user_id": 1, + "name": "Whisper 多语种识别", + "vendor": "OpenAI", + "language": "Multi-lingual", + "base_url": "https://api.openai.com/v1", + "api_key": "sk-***", + "model_name": "whisper-1", + "hotwords": [], + "enable_punctuation": true, + "enable_normalization": true, + "enabled": true, + "created_at": "2024-01-15T10:30:00Z" +} +``` + +--- + +### 3. 创建 ASR 模型 + +```http +POST /api/v1/asr +``` + +**Request Body:** + +```json +{ + "name": "SenseVoice 中文识别", + "vendor": "SiliconFlow", + "language": "zh", + "base_url": "https://api.siliconflow.cn/v1", + "api_key": "sk-your-api-key", + "model_name": "paraformer-v2", + "hotwords": ["小助手", "帮我"], + "enable_punctuation": true, + "enable_normalization": true, + "enabled": true +} +``` + +**Fields 说明:** + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| name | string | 是 | 模型显示名称 | +| vendor | string | 是 | 供应商: "OpenAI" / "SiliconFlow" / "Paraformer" | +| language | string | 是 | 语言: "zh" / "en" / "Multi-lingual" | +| base_url | string | 是 | API Base URL | +| api_key | string | 是 | API Key | +| model_name | string | 否 | 模型名称 | +| hotwords | string[] | 否 | 热词列表,提升识别准确率 | +| enable_punctuation | boolean | 否 | 是否输出标点,默认 true | +| enable_normalization | boolean | 否 | 是否文本规范化,默认 true | +| enabled | boolean | 否 | 是否启用,默认 true | +| id | string | 否 | 指定模型ID,默认自动生成 | + +--- + +### 4. 更新 ASR 模型 + +```http +PUT /api/v1/asr/{id} +``` + +**Request Body:** (部分更新) + +```json +{ + "name": "Whisper-1 优化版", + "language": "zh", + "enable_punctuation": true, + "hotwords": ["新词1", "新词2"] +} +``` + +--- + +### 5. 删除 ASR 模型 + +```http +DELETE /api/v1/asr/{id} +``` + +**Response:** + +```json +{ + "message": "Deleted successfully" +} +``` + +--- + +### 6. 测试 ASR 模型 + +```http +POST /api/v1/asr/{id}/test +``` + +**Request Body:** + +```json +{ + "audio_url": "https://example.com/test-audio.wav" +} +``` + +或使用 Base64 编码的音频数据: + +```json +{ + "audio_data": "UklGRi..." +} +``` + +**Response (成功):** + +```json +{ + "success": true, + "transcript": "您好,请问有什么可以帮助您?", + "language": "zh", + "confidence": 0.95, + "latency_ms": 500 +} +``` + +**Response (失败):** + +```json +{ + "success": false, + "error": "HTTP Error: 401 - Unauthorized" +} +``` + +--- + +### 7. 转写音频 + +```http +POST /api/v1/asr/{id}/transcribe +``` + +**Query Parameters:** + +| 参数 | 类型 | 必填 | 说明 | +|------|------|------|------| +| audio_url | string | 否* | 音频文件URL | +| audio_data | string | 否* | Base64编码的音频数据 | +| hotwords | string[] | 否 | 热词列表 | + +*二选一,至少提供一个 + +**Response:** + +```json +{ + "success": true, + "transcript": "您好,请问有什么可以帮助您?", + "language": "zh", + "confidence": 0.95 +} +``` + +--- + +## Schema 定义 + +```python +from enum import Enum +from pydantic import BaseModel +from typing import Optional, List +from datetime import datetime + +class ASRLanguage(str, Enum): + ZH = "zh" + EN = "en" + MULTILINGUAL = "Multi-lingual" + +class ASRModelBase(BaseModel): + name: str + vendor: str + language: str # "zh" | "en" | "Multi-lingual" + base_url: str + api_key: str + model_name: Optional[str] = None + hotwords: List[str] = [] + enable_punctuation: bool = True + enable_normalization: bool = True + enabled: bool = True + +class ASRModelCreate(ASRModelBase): + id: Optional[str] = None + +class ASRModelUpdate(BaseModel): + name: Optional[str] = None + language: Optional[str] = None + base_url: Optional[str] = None + api_key: Optional[str] = None + model_name: Optional[str] = None + hotwords: Optional[List[str]] = None + enable_punctuation: Optional[bool] = None + enable_normalization: Optional[bool] = None + enabled: Optional[bool] = None + +class ASRModelOut(ASRModelBase): + id: str + user_id: int + created_at: datetime + + class Config: + from_attributes = True + +class ASRTestRequest(BaseModel): + audio_url: Optional[str] = None + audio_data: Optional[str] = None # base64 encoded + +class ASRTestResponse(BaseModel): + success: bool + transcript: Optional[str] = None + language: Optional[str] = None + confidence: Optional[float] = None + latency_ms: Optional[int] = None + error: Optional[str] = None +``` + +--- + +## 供应商配置示例 + +### OpenAI Whisper + +```json +{ + "vendor": "OpenAI", + "base_url": "https://api.openai.com/v1", + "api_key": "sk-xxx", + "model_name": "whisper-1", + "language": "Multi-lingual", + "enable_punctuation": true, + "enable_normalization": true +} +``` + +### SiliconFlow Paraformer + +```json +{ + "vendor": "SiliconFlow", + "base_url": "https://api.siliconflow.cn/v1", + "api_key": "sf-xxx", + "model_name": "paraformer-v2", + "language": "zh", + "hotwords": ["产品名称", "公司名"], + "enable_punctuation": true, + "enable_normalization": true +} +``` + +--- + +## 单元测试 + +项目包含完整的单元测试,位于 `api/tests/test_asr.py`。 + +### 测试用例概览 + +| 测试方法 | 说明 | +|----------|------| +| test_get_asr_models_empty | 空数据库获取测试 | +| test_create_asr_model | 创建模型测试 | +| test_create_asr_model_minimal | 最小数据创建测试 | +| test_get_asr_model_by_id | 获取单个模型测试 | +| test_get_asr_model_not_found | 获取不存在模型测试 | +| test_update_asr_model | 更新模型测试 | +| test_delete_asr_model | 删除模型测试 | +| test_list_asr_models_with_pagination | 分页测试 | +| test_filter_asr_models_by_language | 按语言过滤测试 | +| test_filter_asr_models_by_enabled | 按启用状态过滤测试 | +| test_create_asr_model_with_hotwords | 热词配置测试 | +| test_test_asr_model_siliconflow | SiliconFlow 供应商测试 | +| test_test_asr_model_openai | OpenAI 供应商测试 | +| test_different_asr_languages | 多语言测试 | +| test_different_asr_vendors | 多供应商测试 | + +### 运行测试 + +```bash +# 运行 ASR 相关测试 +pytest api/tests/test_asr.py -v + +# 运行所有测试 +pytest api/tests/ -v +``` diff --git a/api/docs/index.md b/api/docs/index.md index be4ed8c..fb685f3 100644 --- a/api/docs/index.md +++ b/api/docs/index.md @@ -7,9 +7,9 @@ | 模块 | 文件 | 说明 | |------|------|------| | 小助手 | [assistant.md](./assistant.md) | AI 助手管理 | -| 模型接入 | [model-access.md](./model-access.md) | LLM/ASR/TTS 模型配置 | -| 语音识别 | [speech-recognition.md](./speech-recognition.md) | ASR 模型配置 | -| 声音资源 | [voice-resources.md](./voice-resources.md) | TTS 声音库管理 | +| LLM 模型 | [llm.md](./llm.md) | LLM 模型配置与管理 | +| ASR 模型 | [asr.md](./asr.md) | 语音识别模型配置 | +| 工具与测试 | [tools.md](./tools.md) | 工具列表与自动测试 | | 历史记录 | [history-records.md](./history-records.md) | 通话记录和转写 | --- diff --git a/api/docs/llm.md b/api/docs/llm.md new file mode 100644 index 0000000..2b2fce9 --- /dev/null +++ b/api/docs/llm.md @@ -0,0 +1,401 @@ +# LLM 模型 (LLM Model) API + +LLM 模型 API 用于管理大语言模型的配置和调用。 + +## 基础信息 + +| 项目 | 值 | +|------|-----| +| Base URL | `/api/v1/llm` | +| 认证方式 | Bearer Token (预留) | + +--- + +## 数据模型 + +### LLMModel + +```typescript +interface LLMModel { + id: string; // 模型唯一标识 (8位UUID) + user_id: number; // 所属用户ID + name: string; // 模型显示名称 + vendor: string; // 供应商: "OpenAI" | "SiliconFlow" | "Dify" | "FastGPT" | 等 + type: string; // 类型: "text" | "embedding" | "rerank" + base_url: string; // API Base URL + api_key: string; // API Key + model_name?: string; // 实际模型名称,如 "gpt-4o" + temperature?: number; // 温度参数 (0-2) + context_length?: int; // 上下文长度 + enabled: boolean; // 是否启用 + created_at: string; + updated_at: string; +} +``` + +--- + +## API 端点 + +### 1. 获取 LLM 模型列表 + +```http +GET /api/v1/llm +``` + +**Query Parameters:** + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|------|------|------|--------|------| +| model_type | string | 否 | - | 过滤类型: "text" \| "embedding" \| "rerank" | +| enabled | boolean | 否 | - | 过滤启用状态 | +| page | int | 否 | 1 | 页码 | +| limit | int | 否 | 50 | 每页数量 | + +**Response:** + +```json +{ + "total": 5, + "page": 1, + "limit": 50, + "list": [ + { + "id": "abc12345", + "user_id": 1, + "name": "GPT-4o", + "vendor": "OpenAI", + "type": "text", + "base_url": "https://api.openai.com/v1", + "api_key": "sk-***", + "model_name": "gpt-4o", + "temperature": 0.7, + "context_length": 128000, + "enabled": true, + "created_at": "2024-01-15T10:30:00Z", + "updated_at": "2024-01-15T10:30:00Z" + }, + { + "id": "def67890", + "user_id": 1, + "name": "Embedding-3-Small", + "vendor": "OpenAI", + "type": "embedding", + "base_url": "https://api.openai.com/v1", + "api_key": "sk-***", + "model_name": "text-embedding-3-small", + "enabled": true + } + ] +} +``` + +--- + +### 2. 获取单个 LLM 模型详情 + +```http +GET /api/v1/llm/{id} +``` + +**Path Parameters:** + +| 参数 | 类型 | 说明 | +|------|------|------| +| id | string | 模型ID | + +**Response:** + +```json +{ + "id": "abc12345", + "user_id": 1, + "name": "GPT-4o", + "vendor": "OpenAI", + "type": "text", + "base_url": "https://api.openai.com/v1", + "api_key": "sk-***", + "model_name": "gpt-4o", + "temperature": 0.7, + "context_length": 128000, + "enabled": true, + "created_at": "2024-01-15T10:30:00Z", + "updated_at": "2024-01-15T10:30:00Z" +} +``` + +--- + +### 3. 创建 LLM 模型 + +```http +POST /api/v1/llm +``` + +**Request Body:** + +```json +{ + "name": "GPT-4o", + "vendor": "OpenAI", + "type": "text", + "base_url": "https://api.openai.com/v1", + "api_key": "sk-your-api-key", + "model_name": "gpt-4o", + "temperature": 0.7, + "context_length": 128000, + "enabled": true +} +``` + +**Fields 说明:** + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| name | string | 是 | 模型显示名称 | +| vendor | string | 是 | 供应商名称 | +| type | string | 是 | 模型类型: "text" / "embedding" / "rerank" | +| base_url | string | 是 | API Base URL | +| api_key | string | 是 | API Key | +| model_name | string | 否 | 实际模型名称 | +| temperature | number | 否 | 温度参数,默认 0.7 | +| context_length | int | 否 | 上下文长度 | +| enabled | boolean | 否 | 是否启用,默认 true | +| id | string | 否 | 指定模型ID,默认自动生成 | + +--- + +### 4. 更新 LLM 模型 + +```http +PUT /api/v1/llm/{id} +``` + +**Request Body:** (部分更新) + +```json +{ + "name": "GPT-4o-Updated", + "temperature": 0.8, + "enabled": false +} +``` + +--- + +### 5. 删除 LLM 模型 + +```http +DELETE /api/v1/llm/{id} +``` + +**Response:** + +```json +{ + "message": "Deleted successfully" +} +``` + +--- + +### 6. 测试 LLM 模型连接 + +```http +POST /api/v1/llm/{id}/test +``` + +**Response:** + +```json +{ + "success": true, + "latency_ms": 150, + "message": "Connection successful" +} +``` + +**错误响应:** + +```json +{ + "success": false, + "latency_ms": 200, + "message": "HTTP Error: 401 - Unauthorized" +} +``` + +--- + +### 7. 与 LLM 模型对话 + +```http +POST /api/v1/llm/{id}/chat +``` + +**Query Parameters:** + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|------|------|------|--------|------| +| message | string | 是 | - | 用户消息 | +| system_prompt | string | 否 | - | 系统提示词 | +| max_tokens | int | 否 | 1000 | 最大生成token数 | +| temperature | number | 否 | 模型配置 | 温度参数 | + +**Response:** + +```json +{ + "success": true, + "reply": "您好!有什么可以帮助您的?", + "usage": { + "prompt_tokens": 20, + "completion_tokens": 15, + "total_tokens": 35 + } +} +``` + +--- + +## Schema 定义 + +```python +from enum import Enum +from pydantic import BaseModel +from typing import Optional +from datetime import datetime + +class LLMModelType(str, Enum): + TEXT = "text" + EMBEDDING = "embedding" + RERANK = "rerank" + +class LLMModelBase(BaseModel): + name: str + vendor: str + type: LLMModelType + base_url: str + api_key: str + model_name: Optional[str] = None + temperature: Optional[float] = None + context_length: Optional[int] = None + enabled: bool = True + +class LLMModelCreate(LLMModelBase): + id: Optional[str] = None + +class LLMModelUpdate(BaseModel): + name: Optional[str] = None + vendor: Optional[str] = None + base_url: Optional[str] = None + api_key: Optional[str] = None + model_name: Optional[str] = None + temperature: Optional[float] = None + context_length: Optional[int] = None + enabled: Optional[bool] = None + +class LLMModelOut(LLMModelBase): + id: str + user_id: int + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + +class LLMModelTestResponse(BaseModel): + success: bool + latency_ms: int + message: str +``` + +--- + +## 供应商配置示例 + +### OpenAI + +```json +{ + "vendor": "OpenAI", + "base_url": "https://api.openai.com/v1", + "api_key": "sk-xxx", + "model_name": "gpt-4o", + "type": "text", + "temperature": 0.7 +} +``` + +### SiliconFlow + +```json +{ + "vendor": "SiliconFlow", + "base_url": "https://api.siliconflow.com/v1", + "api_key": "sf-xxx", + "model_name": "deepseek-v3", + "type": "text", + "temperature": 0.7 +} +``` + +### Dify + +```json +{ + "vendor": "Dify", + "base_url": "https://your-dify.domain.com/v1", + "api_key": "app-xxx", + "model_name": "gpt-4", + "type": "text" +} +``` + +### Embedding 模型 + +```json +{ + "vendor": "OpenAI", + "base_url": "https://api.openai.com/v1", + "api_key": "sk-xxx", + "model_name": "text-embedding-3-small", + "type": "embedding" +} +``` + +--- + +## 单元测试 + +项目包含完整的单元测试,位于 `api/tests/test_llm.py`。 + +### 测试用例概览 + +| 测试方法 | 说明 | +|----------|------| +| test_get_llm_models_empty | 空数据库获取测试 | +| test_create_llm_model | 创建模型测试 | +| test_create_llm_model_minimal | 最小数据创建测试 | +| test_get_llm_model_by_id | 获取单个模型测试 | +| test_get_llm_model_not_found | 获取不存在模型测试 | +| test_update_llm_model | 更新模型测试 | +| test_delete_llm_model | 删除模型测试 | +| test_list_llm_models_with_pagination | 分页测试 | +| test_filter_llm_models_by_type | 按类型过滤测试 | +| test_filter_llm_models_by_enabled | 按启用状态过滤测试 | +| test_create_llm_model_with_all_fields | 全字段创建测试 | +| test_test_llm_model_success | 测试连接成功测试 | +| test_test_llm_model_failure | 测试连接失败测试 | +| test_different_llm_vendors | 多供应商测试 | +| test_embedding_llm_model | Embedding 模型测试 | + +### 运行测试 + +```bash +# 运行 LLM 相关测试 +pytest api/tests/test_llm.py -v + +# 运行所有测试 +pytest api/tests/ -v +``` diff --git a/api/docs/tools.md b/api/docs/tools.md new file mode 100644 index 0000000..7361ffe --- /dev/null +++ b/api/docs/tools.md @@ -0,0 +1,445 @@ +# 工具与自动测试 (Tools & Autotest) API + +工具与自动测试 API 用于管理可用工具列表和自动测试功能。 + +## 基础信息 + +| 项目 | 值 | +|------|-----| +| Base URL | `/api/v1/tools` | +| 认证方式 | Bearer Token (预留) | + +--- + +## 可用工具 (Tool Registry) + +系统内置以下工具: + +| 工具ID | 名称 | 说明 | +|--------|------|------| +| search | 网络搜索 | 搜索互联网获取最新信息 | +| calculator | 计算器 | 执行数学计算 | +| weather | 天气查询 | 查询指定城市的天气 | +| translate | 翻译 | 翻译文本到指定语言 | +| knowledge | 知识库查询 | 从知识库中检索相关信息 | +| code_interpreter | 代码执行 | 安全地执行Python代码 | + +--- + +## API 端点 + +### 1. 获取可用工具列表 + +```http +GET /api/v1/tools/list +``` + +**Response:** + +```json +{ + "tools": { + "search": { + "name": "网络搜索", + "description": "搜索互联网获取最新信息", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "搜索关键词"} + }, + "required": ["query"] + } + }, + "calculator": { + "name": "计算器", + "description": "执行数学计算", + "parameters": { + "type": "object", + "properties": { + "expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"} + }, + "required": ["expression"] + } + }, + "weather": { + "name": "天气查询", + "description": "查询指定城市的天气", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "城市名称"} + }, + "required": ["city"] + } + }, + "translate": { + "name": "翻译", + "description": "翻译文本到指定语言", + "parameters": { + "type": "object", + "properties": { + "text": {"type": "string", "description": "要翻译的文本"}, + "target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"} + }, + "required": ["text", "target_lang"] + } + }, + "knowledge": { + "name": "知识库查询", + "description": "从知识库中检索相关信息", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "查询内容"}, + "kb_id": {"type": "string", "description": "知识库ID"} + }, + "required": ["query"] + } + }, + "code_interpreter": { + "name": "代码执行", + "description": "安全地执行Python代码", + "parameters": { + "type": "object", + "properties": { + "code": {"type": "string", "description": "要执行的Python代码"} + }, + "required": ["code"] + } + } + } +} +``` + +--- + +### 2. 获取工具详情 + +```http +GET /api/v1/tools/list/{tool_id} +``` + +**Path Parameters:** + +| 参数 | 类型 | 说明 | +|------|------|------| +| tool_id | string | 工具ID | + +**Response:** + +```json +{ + "name": "计算器", + "description": "执行数学计算", + "parameters": { + "type": "object", + "properties": { + "expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"} + }, + "required": ["expression"] + } +} +``` + +**错误响应 (工具不存在):** + +```json +{ + "detail": "Tool not found" +} +``` + +--- + +### 3. 健康检查 + +```http +GET /api/v1/tools/health +``` + +**Response:** + +```json +{ + "status": "healthy", + "timestamp": 1705315200.123, + "tools": ["search", "calculator", "weather", "translate", "knowledge", "code_interpreter"] +} +``` + +--- + +## 自动测试 (Autotest) + +### 4. 运行完整自动测试 + +```http +POST /api/v1/tools/autotest +``` + +**Query Parameters:** + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|------|------|------|--------|------| +| llm_model_id | string | 否 | - | LLM 模型ID | +| asr_model_id | string | 否 | - | ASR 模型ID | +| test_llm | boolean | 否 | true | 是否测试LLM | +| test_asr | boolean | 否 | true | 是否测试ASR | + +**Response:** + +```json +{ + "id": "abc12345", + "started_at": 1705315200.0, + "duration_ms": 2500, + "tests": [ + { + "name": "Model Existence", + "passed": true, + "message": "Found model: GPT-4o", + "duration_ms": 15 + }, + { + "name": "API Connection", + "passed": true, + "message": "Latency: 150ms", + "duration_ms": 150 + }, + { + "name": "Temperature Setting", + "passed": true, + "message": "temperature=0.7" + }, + { + "name": "Streaming Support", + "passed": true, + "message": "Received 15 chunks", + "duration_ms": 800 + } + ], + "summary": { + "passed": 4, + "failed": 0, + "total": 4 + } +} +``` + +--- + +### 5. 测试单个 LLM 模型 + +```http +POST /api/v1/tools/autotest/llm/{model_id} +``` + +**Path Parameters:** + +| 参数 | 类型 | 说明 | +|------|------|------| +| model_id | string | LLM 模型ID | + +**Response:** + +```json +{ + "id": "llm_test_001", + "started_at": 1705315200.0, + "duration_ms": 1200, + "tests": [ + { + "name": "Model Existence", + "passed": true, + "message": "Found model: GPT-4o", + "duration_ms": 10 + }, + { + "name": "API Connection", + "passed": true, + "message": "Latency: 180ms", + "duration_ms": 180 + }, + { + "name": "Temperature Setting", + "passed": true, + "message": "temperature=0.7" + }, + { + "name": "Streaming Support", + "passed": true, + "message": "Received 12 chunks", + "duration_ms": 650 + } + ], + "summary": { + "passed": 4, + "failed": 0, + "total": 4 + } +} +``` + +--- + +### 6. 测试单个 ASR 模型 + +```http +POST /api/v1/tools/autotest/asr/{model_id} +``` + +**Path Parameters:** + +| 参数 | 类型 | 说明 | +|------|------|------| +| model_id | string | ASR 模型ID | + +**Response:** + +```json +{ + "id": "asr_test_001", + "started_at": 1705315200.0, + "duration_ms": 800, + "tests": [ + { + "name": "Model Existence", + "passed": true, + "message": "Found model: Whisper-1", + "duration_ms": 8 + }, + { + "name": "Hotwords Config", + "passed": true, + "message": "Hotwords: 3 words" + }, + { + "name": "API Availability", + "passed": true, + "message": "Status: 200", + "duration_ms": 250 + }, + { + "name": "Language Config", + "passed": true, + "message": "Language: zh" + } + ], + "summary": { + "passed": 4, + "failed": 0, + "total": 4 + } +} +``` + +--- + +### 7. 发送测试消息 + +```http +POST /api/v1/tools/test-message +``` + +**Query Parameters:** + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|------|------|------|--------|------| +| llm_model_id | string | 是 | - | LLM 模型ID | +| message | string | 否 | "Hello, this is a test message." | 测试消息 | + +**Response:** + +```json +{ + "success": true, + "reply": "Hello! This is a test reply from GPT-4o.", + "usage": { + "prompt_tokens": 15, + "completion_tokens": 12, + "total_tokens": 27 + } +} +``` + +**错误响应 (模型不存在):** + +```json +{ + "detail": "LLM Model not found" +} +``` + +--- + +## 测试结果结构 + +### AutotestResult + +```typescript +interface AutotestResult { + id: string; // 测试ID + started_at: number; // 开始时间戳 + duration_ms: number; // 总耗时(毫秒) + tests: TestCase[]; // 测试用例列表 + summary: TestSummary; // 测试摘要 +} + +interface TestCase { + name: string; // 测试名称 + passed: boolean; // 是否通过 + message: string; // 测试消息 + duration_ms: number; // 耗时(毫秒) +} + +interface TestSummary { + passed: number; // 通过数量 + failed: number; // 失败数量 + total: number; // 总数量 +} +``` + +--- + +## 测试项目说明 + +### LLM 模型测试项目 + +| 测试名称 | 说明 | +|----------|------| +| Model Existence | 检查模型是否存在于数据库 | +| API Connection | 测试 API 连接并测量延迟 | +| Temperature Setting | 检查温度配置 | +| Streaming Support | 测试流式响应支持 | + +### ASR 模型测试项目 + +| 测试名称 | 说明 | +|----------|------| +| Model Existence | 检查模型是否存在于数据库 | +| Hotwords Config | 检查热词配置 | +| API Availability | 测试 API 可用性 | +| Language Config | 检查语言配置 | + +--- + +## 单元测试 + +项目包含完整的单元测试,位于 `api/tests/test_tools.py`。 + +### 测试用例概览 + +| 测试类 | 说明 | +|--------|------| +| TestToolsAPI | 工具列表、健康检查等基础功能测试 | +| TestAutotestAPI | 自动测试功能完整测试 | + +### 运行测试 + +```bash +# 运行工具相关测试 +pytest api/tests/test_tools.py -v + +# 运行所有测试 +pytest api/tests/ -v +``` diff --git a/api/init_db.py b/api/init_db.py index 0376ee2..5da97b6 100644 --- a/api/init_db.py +++ b/api/init_db.py @@ -7,7 +7,7 @@ import sys sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from app.db import Base, engine, DATABASE_URL -from app.models import Voice +from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel def init_db(): @@ -22,10 +22,379 @@ def init_db(): print("✅ 数据库表创建完成") +def init_default_data(): + from sqlalchemy.orm import Session + from app.db import SessionLocal + from app.models import Voice + + db = SessionLocal() + try: + # 检查是否已有数据 + if db.query(Voice).count() == 0: + # SiliconFlow CosyVoice 2.0 预设声音 (8个) + # 参考: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech + voices = [ + # 男声 (Male Voices) + Voice(id="alex", name="Alex", vendor="SiliconFlow", gender="Male", language="en", + description="Steady male voice.", is_system=True), + Voice(id="benjamin", name="Benjamin", vendor="SiliconFlow", gender="Male", language="en", + description="Deep male voice.", is_system=True), + Voice(id="charles", name="Charles", vendor="SiliconFlow", gender="Male", language="en", + description="Magnetic male voice.", is_system=True), + Voice(id="david", name="David", vendor="SiliconFlow", gender="Male", language="en", + description="Cheerful male voice.", is_system=True), + # 女声 (Female Voices) + Voice(id="anna", name="Anna", vendor="SiliconFlow", gender="Female", language="en", + description="Steady female voice.", is_system=True), + Voice(id="bella", name="Bella", vendor="SiliconFlow", gender="Female", language="en", + description="Passionate female voice.", is_system=True), + Voice(id="claire", name="Claire", vendor="SiliconFlow", gender="Female", language="en", + description="Gentle female voice.", is_system=True), + Voice(id="diana", name="Diana", vendor="SiliconFlow", gender="Female", language="en", + description="Cheerful female voice.", is_system=True), + # 中文方言 (Chinese Dialects) - 可选扩展 + Voice(id="amador", name="Amador", vendor="SiliconFlow", gender="Male", language="zh", + description="Male voice with Spanish accent."), + Voice(id="aelora", name="Aelora", vendor="SiliconFlow", gender="Female", language="en", + description="Elegant female voice."), + Voice(id="aelwin", name="Aelwin", vendor="SiliconFlow", gender="Male", language="en", + description="Deep male voice."), + Voice(id="blooming", name="Blooming", vendor="SiliconFlow", gender="Female", language="en", + description="Fresh and clear female voice."), + Voice(id="elysia", name="Elysia", vendor="SiliconFlow", gender="Female", language="en", + description="Smooth and silky female voice."), + Voice(id="leo", name="Leo", vendor="SiliconFlow", gender="Male", language="en", + description="Young male voice."), + Voice(id="lin", name="Lin", vendor="SiliconFlow", gender="Female", language="zh", + description="Standard Chinese female voice."), + Voice(id="rose", name="Rose", vendor="SiliconFlow", gender="Female", language="en", + description="Soft and gentle female voice."), + Voice(id="shao", name="Shao", vendor="SiliconFlow", gender="Male", language="zh", + description="Deep Chinese male voice."), + Voice(id="sky", name="Sky", vendor="SiliconFlow", gender="Male", language="en", + description="Clear and bright male voice."), + Voice(id="ael西山", name="Ael西山", vendor="SiliconFlow", gender="Female", language="zh", + description="Female voice with Chinese dialect."), + ] + for v in voices: + db.add(v) + db.commit() + print("✅ 默认声音数据已初始化 (SiliconFlow CosyVoice 2.0)") + finally: + db.close() + + +def init_default_assistants(): + """初始化默认助手""" + from sqlalchemy.orm import Session + from app.db import SessionLocal + + db = SessionLocal() + try: + if db.query(Assistant).count() == 0: + assistants = [ + Assistant( + id="default", + user_id=1, + name="AI 助手", + call_count=0, + opener="你好!我是AI助手,有什么可以帮你的吗?", + prompt="你是一个友好的AI助手,请用简洁清晰的语言回答用户的问题。", + language="zh", + voice="anna", + speed=1.0, + hotwords=[], + tools=["search", "calculator"], + interruption_sensitivity=500, + config_mode="platform", + llm_model_id="deepseek-chat", + asr_model_id="paraformer-v2", + ), + Assistant( + id="customer_service", + user_id=1, + name="客服助手", + call_count=0, + opener="您好,欢迎致电客服中心,请问有什么可以帮您?", + prompt="你是一个专业的客服人员,耐心解答客户问题,提供优质的服务体验。", + language="zh", + voice="bella", + speed=1.0, + hotwords=["客服", "投诉", "咨询"], + tools=["search"], + interruption_sensitivity=600, + config_mode="platform", + ), + Assistant( + id="english_tutor", + user_id=1, + name="英语导师", + call_count=0, + opener="Hello! I'm your English learning companion. How can I help you today?", + prompt="You are a friendly English tutor. Help users practice English conversation and explain grammar points clearly.", + language="en", + voice="alex", + speed=1.0, + hotwords=["grammar", "vocabulary", "practice"], + tools=[], + interruption_sensitivity=400, + config_mode="platform", + ), + ] + for a in assistants: + db.add(a) + db.commit() + print("✅ 默认助手数据已初始化") + finally: + db.close() + + +def init_default_workflows(): + """初始化默认工作流""" + from sqlalchemy.orm import Session + from app.db import SessionLocal + from datetime import datetime + + db = SessionLocal() + try: + if db.query(Workflow).count() == 0: + now = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") + workflows = [ + Workflow( + id="simple_conversation", + user_id=1, + name="简单对话", + node_count=2, + created_at=now, + updated_at=now, + global_prompt="处理简单的对话流程,用户问什么答什么。", + nodes=[ + {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}}, + {"id": "2", "type": "ai_reply", "position": {"x": 300, "y": 100}, "data": {"label": "AI回复"}}, + ], + edges=[{"source": "1", "target": "2", "id": "e1-2"}], + ), + Workflow( + id="voice_input_flow", + user_id=1, + name="语音输入流程", + node_count=4, + created_at=now, + updated_at=now, + global_prompt="处理语音输入的完整流程。", + nodes=[ + {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}}, + {"id": "2", "type": "asr", "position": {"x": 250, "y": 100}, "data": {"label": "语音识别"}}, + {"id": "3", "type": "llm", "position": {"x": 400, "y": 100}, "data": {"label": "LLM处理"}}, + {"id": "4", "type": "tts", "position": {"x": 550, "y": 100}, "data": {"label": "语音合成"}}, + ], + edges=[ + {"source": "1", "target": "2", "id": "e1-2"}, + {"source": "2", "target": "3", "id": "e2-3"}, + {"source": "3", "target": "4", "id": "e3-4"}, + ], + ), + ] + for w in workflows: + db.add(w) + db.commit() + print("✅ 默认工作流数据已初始化") + finally: + db.close() + + +def init_default_knowledge_bases(): + """初始化默认知识库""" + from sqlalchemy.orm import Session + from app.db import SessionLocal + + db = SessionLocal() + try: + if db.query(KnowledgeBase).count() == 0: + kb = KnowledgeBase( + id="default_kb", + user_id=1, + name="默认知识库", + description="系统默认知识库,用于存储常见问题解答。", + embedding_model="text-embedding-3-small", + chunk_size=500, + chunk_overlap=50, + doc_count=0, + chunk_count=0, + status="active", + ) + db.add(kb) + db.commit() + print("✅ 默认知识库已初始化") + finally: + db.close() + + +def init_default_llm_models(): + """初始化默认LLM模型""" + from sqlalchemy.orm import Session + from app.db import SessionLocal + + db = SessionLocal() + try: + if db.query(LLMModel).count() == 0: + llm_models = [ + LLMModel( + id="deepseek-chat", + user_id=1, + name="DeepSeek Chat", + vendor="SiliconFlow", + type="text", + base_url="https://api.deepseek.com", + api_key="YOUR_API_KEY", # 用户需替换 + model_name="deepseek-chat", + temperature=0.7, + context_length=4096, + enabled=True, + ), + LLMModel( + id="deepseek-reasoner", + user_id=1, + name="DeepSeek Reasoner", + vendor="SiliconFlow", + type="text", + base_url="https://api.deepseek.com", + api_key="YOUR_API_KEY", + model_name="deepseek-reasoner", + temperature=0.7, + context_length=4096, + enabled=True, + ), + LLMModel( + id="gpt-4o", + user_id=1, + name="GPT-4o", + vendor="OpenAI", + type="text", + base_url="https://api.openai.com/v1", + api_key="YOUR_API_KEY", + model_name="gpt-4o", + temperature=0.7, + context_length=16384, + enabled=True, + ), + LLMModel( + id="glm-4", + user_id=1, + name="GLM-4", + vendor="ZhipuAI", + type="text", + base_url="https://open.bigmodel.cn/api/paas/v4", + api_key="YOUR_API_KEY", + model_name="glm-4", + temperature=0.7, + context_length=8192, + enabled=True, + ), + LLMModel( + id="text-embedding-3-small", + user_id=1, + name="Embedding 3 Small", + vendor="OpenAI", + type="embedding", + base_url="https://api.openai.com/v1", + api_key="YOUR_API_KEY", + model_name="text-embedding-3-small", + enabled=True, + ), + ] + for m in llm_models: + db.add(m) + db.commit() + print("✅ 默认LLM模型已初始化") + finally: + db.close() + + +def init_default_asr_models(): + """初始化默认ASR模型""" + from sqlalchemy.orm import Session + from app.db import SessionLocal + + db = SessionLocal() + try: + if db.query(ASRModel).count() == 0: + asr_models = [ + ASRModel( + id="paraformer-v2", + user_id=1, + name="Paraformer V2", + vendor="SiliconFlow", + language="zh", + base_url="https://api.siliconflow.cn/v1", + api_key="YOUR_API_KEY", + model_name="paraformer-v2", + hotwords=["人工智能", "机器学习"], + enable_punctuation=True, + enable_normalization=True, + enabled=True, + ), + ASRModel( + id="paraformer-en", + user_id=1, + name="Paraformer English", + vendor="SiliconFlow", + language="en", + base_url="https://api.siliconflow.cn/v1", + api_key="YOUR_API_KEY", + model_name="paraformer-en", + hotwords=[], + enable_punctuation=True, + enable_normalization=True, + enabled=True, + ), + ASRModel( + id="whisper-1", + user_id=1, + name="Whisper", + vendor="OpenAI", + language="Multi-lingual", + base_url="https://api.openai.com/v1", + api_key="YOUR_API_KEY", + model_name="whisper-1", + hotwords=[], + enable_punctuation=True, + enable_normalization=True, + enabled=True, + ), + ASRModel( + id="sensevoice", + user_id=1, + name="SenseVoice", + vendor="SiliconFlow", + language="Multi-lingual", + base_url="https://api.siliconflow.cn/v1", + api_key="YOUR_API_KEY", + model_name="sensevoice", + hotwords=[], + enable_punctuation=True, + enable_normalization=True, + enabled=True, + ), + ] + for m in asr_models: + db.add(m) + db.commit() + print("✅ 默认ASR模型已初始化") + finally: + db.close() + + if __name__ == "__main__": # 确保 data 目录存在 data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data") os.makedirs(data_dir, exist_ok=True) - + init_db() + init_default_data() + init_default_assistants() + init_default_workflows() + init_default_knowledge_bases() + init_default_llm_models() + init_default_asr_models() print("🎉 数据库初始化完成!") diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 0c6104e..015cef7 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -100,3 +100,38 @@ def sample_call_record_data(): "assistant_id": None, "source": "debug" } + + +@pytest.fixture +def sample_llm_model_data(): + """Sample LLM model data for testing""" + return { + "id": "test-llm-001", + "name": "Test LLM Model", + "vendor": "TestVendor", + "type": "text", + "base_url": "https://api.test.com/v1", + "api_key": "test-api-key", + "model_name": "test-model", + "temperature": 0.7, + "context_length": 4096, + "enabled": True + } + + +@pytest.fixture +def sample_asr_model_data(): + """Sample ASR model data for testing""" + return { + "id": "test-asr-001", + "name": "Test ASR Model", + "vendor": "TestVendor", + "language": "zh", + "base_url": "https://api.test.com/v1", + "api_key": "test-api-key", + "model_name": "paraformer-v2", + "hotwords": ["测试", "语音"], + "enable_punctuation": True, + "enable_normalization": True, + "enabled": True + } diff --git a/api/tests/test_asr.py b/api/tests/test_asr.py new file mode 100644 index 0000000..3f8cc77 --- /dev/null +++ b/api/tests/test_asr.py @@ -0,0 +1,289 @@ +"""Tests for ASR Model API endpoints""" +import pytest +from unittest.mock import patch, MagicMock + + +class TestASRModelAPI: + """Test cases for ASR Model endpoints""" + + def test_get_asr_models_empty(self, client): + """Test getting ASR models when database is empty""" + response = client.get("/api/asr") + assert response.status_code == 200 + data = response.json() + assert "total" in data + assert "list" in data + assert data["total"] == 0 + + def test_create_asr_model(self, client, sample_asr_model_data): + """Test creating a new ASR model""" + response = client.post("/api/asr", json=sample_asr_model_data) + assert response.status_code == 200 + data = response.json() + assert data["name"] == sample_asr_model_data["name"] + assert data["vendor"] == sample_asr_model_data["vendor"] + assert data["language"] == sample_asr_model_data["language"] + assert "id" in data + + def test_create_asr_model_minimal(self, client): + """Test creating an ASR model with minimal required data""" + data = { + "name": "Minimal ASR", + "vendor": "Test", + "language": "zh", + "base_url": "https://api.test.com", + "api_key": "test-key" + } + response = client.post("/api/asr", json=data) + assert response.status_code == 200 + assert response.json()["name"] == "Minimal ASR" + + def test_get_asr_model_by_id(self, client, sample_asr_model_data): + """Test getting a specific ASR model by ID""" + # Create first + create_response = client.post("/api/asr", json=sample_asr_model_data) + model_id = create_response.json()["id"] + + # Get by ID + response = client.get(f"/api/asr/{model_id}") + assert response.status_code == 200 + data = response.json() + assert data["id"] == model_id + assert data["name"] == sample_asr_model_data["name"] + + def test_get_asr_model_not_found(self, client): + """Test getting a non-existent ASR model""" + response = client.get("/api/asr/non-existent-id") + assert response.status_code == 404 + + def test_update_asr_model(self, client, sample_asr_model_data): + """Test updating an ASR model""" + # Create first + create_response = client.post("/api/asr", json=sample_asr_model_data) + model_id = create_response.json()["id"] + + # Update + update_data = { + "name": "Updated ASR Model", + "language": "en", + "enable_punctuation": False + } + response = client.put(f"/api/asr/{model_id}", json=update_data) + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Updated ASR Model" + assert data["language"] == "en" + assert data["enable_punctuation"] == False + + def test_delete_asr_model(self, client, sample_asr_model_data): + """Test deleting an ASR model""" + # Create first + create_response = client.post("/api/asr", json=sample_asr_model_data) + model_id = create_response.json()["id"] + + # Delete + response = client.delete(f"/api/asr/{model_id}") + assert response.status_code == 200 + + # Verify deleted + get_response = client.get(f"/api/asr/{model_id}") + assert get_response.status_code == 404 + + def test_list_asr_models_with_pagination(self, client, sample_asr_model_data): + """Test listing ASR models with pagination""" + # Create multiple models + for i in range(3): + data = sample_asr_model_data.copy() + data["id"] = f"test-asr-{i}" + data["name"] = f"ASR Model {i}" + client.post("/api/asr", json=data) + + # Test pagination + response = client.get("/api/asr?page=1&limit=2") + assert response.status_code == 200 + data = response.json() + assert data["total"] == 3 + assert len(data["list"]) == 2 + + def test_filter_asr_models_by_language(self, client, sample_asr_model_data): + """Test filtering ASR models by language""" + # Create models with different languages + for i, lang in enumerate(["zh", "en", "Multi-lingual"]): + data = sample_asr_model_data.copy() + data["id"] = f"test-asr-{lang}" + data["name"] = f"ASR {lang}" + data["language"] = lang + client.post("/api/asr", json=data) + + # Filter by language + response = client.get("/api/asr?language=zh") + assert response.status_code == 200 + data = response.json() + assert data["total"] >= 1 + for model in data["list"]: + assert model["language"] == "zh" + + def test_filter_asr_models_by_enabled(self, client, sample_asr_model_data): + """Test filtering ASR models by enabled status""" + # Create enabled and disabled models + data = sample_asr_model_data.copy() + data["id"] = "test-asr-enabled" + data["name"] = "Enabled ASR" + data["enabled"] = True + client.post("/api/asr", json=data) + + data["id"] = "test-asr-disabled" + data["name"] = "Disabled ASR" + data["enabled"] = False + client.post("/api/asr", json=data) + + # Filter by enabled + response = client.get("/api/asr?enabled=true") + assert response.status_code == 200 + data = response.json() + for model in data["list"]: + assert model["enabled"] == True + + def test_create_asr_model_with_hotwords(self, client): + """Test creating an ASR model with hotwords""" + data = { + "id": "asr-hotwords", + "name": "ASR with Hotwords", + "vendor": "SiliconFlow", + "language": "zh", + "base_url": "https://api.siliconflow.cn/v1", + "api_key": "test-key", + "model_name": "paraformer-v2", + "hotwords": ["你好", "查询", "帮助"], + "enable_punctuation": True, + "enable_normalization": True + } + response = client.post("/api/asr", json=data) + assert response.status_code == 200 + result = response.json() + assert result["hotwords"] == ["你好", "查询", "帮助"] + + def test_create_asr_model_with_all_fields(self, client): + """Test creating an ASR model with all fields""" + data = { + "id": "full-asr", + "name": "Full ASR Model", + "vendor": "SiliconFlow", + "language": "zh", + "base_url": "https://api.siliconflow.cn/v1", + "api_key": "sk-test", + "model_name": "paraformer-v2", + "hotwords": ["测试"], + "enable_punctuation": True, + "enable_normalization": True, + "enabled": True + } + response = client.post("/api/asr", json=data) + assert response.status_code == 200 + result = response.json() + assert result["name"] == "Full ASR Model" + assert result["enable_punctuation"] == True + assert result["enable_normalization"] == True + + @patch('httpx.Client') + def test_test_asr_model_siliconflow(self, mock_client_class, client, sample_asr_model_data): + """Test testing an ASR model with SiliconFlow vendor""" + # Create model first + sample_asr_model_data["vendor"] = "SiliconFlow" + create_response = client.post("/api/asr", json=sample_asr_model_data) + model_id = create_response.json()["id"] + + # Mock the HTTP response + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "results": [{"transcript": "测试文本", "language": "zh"}] + } + mock_response.raise_for_status = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + + with patch('app.routers.asr.httpx.Client', return_value=mock_client): + response = client.post(f"/api/asr/{model_id}/test") + assert response.status_code == 200 + data = response.json() + assert data["success"] == True + + @patch('httpx.Client') + def test_test_asr_model_openai(self, mock_client_class, client, sample_asr_model_data): + """Test testing an ASR model with OpenAI vendor""" + # Create model with OpenAI vendor + sample_asr_model_data["vendor"] = "OpenAI" + sample_asr_model_data["id"] = "test-asr-openai" + create_response = client.post("/api/asr", json=sample_asr_model_data) + model_id = create_response.json()["id"] + + # Mock the HTTP response + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"text": "Test transcript"} + mock_response.raise_for_status = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + + with patch('app.routers.asr.httpx.Client', return_value=mock_client): + response = client.post(f"/api/asr/{model_id}/test") + assert response.status_code == 200 + + @patch('httpx.Client') + def test_test_asr_model_failure(self, mock_client_class, client, sample_asr_model_data): + """Test testing an ASR model with failed connection""" + # Create model first + create_response = client.post("/api/asr", json=sample_asr_model_data) + model_id = create_response.json()["id"] + + # Mock HTTP error + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized")) + mock_client.get.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + + with patch('app.routers.asr.httpx.Client', return_value=mock_client): + response = client.post(f"/api/asr/{model_id}/test") + assert response.status_code == 200 + data = response.json() + assert data["success"] == False + + def test_different_asr_languages(self, client): + """Test creating ASR models with different languages""" + for lang in ["zh", "en", "Multi-lingual"]: + data = { + "id": f"asr-lang-{lang}", + "name": f"ASR {lang}", + "vendor": "SiliconFlow", + "language": lang, + "base_url": "https://api.siliconflow.cn/v1", + "api_key": "test-key" + } + response = client.post("/api/asr", json=data) + assert response.status_code == 200 + assert response.json()["language"] == lang + + def test_different_asr_vendors(self, client): + """Test creating ASR models with different vendors""" + vendors = ["SiliconFlow", "OpenAI", "Azure"] + for vendor in vendors: + data = { + "id": f"asr-vendor-{vendor.lower()}", + "name": f"ASR {vendor}", + "vendor": vendor, + "language": "zh", + "base_url": f"https://api.{vendor.lower()}.com/v1", + "api_key": "test-key" + } + response = client.post("/api/asr", json=data) + assert response.status_code == 200 + assert response.json()["vendor"] == vendor diff --git a/api/tests/test_llm.py b/api/tests/test_llm.py new file mode 100644 index 0000000..d626928 --- /dev/null +++ b/api/tests/test_llm.py @@ -0,0 +1,246 @@ +"""Tests for LLM Model API endpoints""" +import pytest +from unittest.mock import patch, MagicMock + + +class TestLLMModelAPI: + """Test cases for LLM Model endpoints""" + + def test_get_llm_models_empty(self, client): + """Test getting LLM models when database is empty""" + response = client.get("/api/llm") + assert response.status_code == 200 + data = response.json() + assert "total" in data + assert "list" in data + assert data["total"] == 0 + + def test_create_llm_model(self, client, sample_llm_model_data): + """Test creating a new LLM model""" + response = client.post("/api/llm", json=sample_llm_model_data) + assert response.status_code == 200 + data = response.json() + assert data["name"] == sample_llm_model_data["name"] + assert data["vendor"] == sample_llm_model_data["vendor"] + assert data["type"] == sample_llm_model_data["type"] + assert data["base_url"] == sample_llm_model_data["base_url"] + assert "id" in data + + def test_create_llm_model_minimal(self, client): + """Test creating an LLM model with minimal required data""" + data = { + "name": "Minimal LLM", + "vendor": "Test", + "type": "text", + "base_url": "https://api.test.com", + "api_key": "test-key" + } + response = client.post("/api/llm", json=data) + assert response.status_code == 200 + assert response.json()["name"] == "Minimal LLM" + + def test_get_llm_model_by_id(self, client, sample_llm_model_data): + """Test getting a specific LLM model by ID""" + # Create first + create_response = client.post("/api/llm", json=sample_llm_model_data) + model_id = create_response.json()["id"] + + # Get by ID + response = client.get(f"/api/llm/{model_id}") + assert response.status_code == 200 + data = response.json() + assert data["id"] == model_id + assert data["name"] == sample_llm_model_data["name"] + + def test_get_llm_model_not_found(self, client): + """Test getting a non-existent LLM model""" + response = client.get("/api/llm/non-existent-id") + assert response.status_code == 404 + + def test_update_llm_model(self, client, sample_llm_model_data): + """Test updating an LLM model""" + # Create first + create_response = client.post("/api/llm", json=sample_llm_model_data) + model_id = create_response.json()["id"] + + # Update + update_data = { + "name": "Updated LLM Model", + "temperature": 0.5, + "context_length": 8192 + } + response = client.put(f"/api/llm/{model_id}", json=update_data) + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Updated LLM Model" + assert data["temperature"] == 0.5 + assert data["context_length"] == 8192 + + def test_delete_llm_model(self, client, sample_llm_model_data): + """Test deleting an LLM model""" + # Create first + create_response = client.post("/api/llm", json=sample_llm_model_data) + model_id = create_response.json()["id"] + + # Delete + response = client.delete(f"/api/llm/{model_id}") + assert response.status_code == 200 + + # Verify deleted + get_response = client.get(f"/api/llm/{model_id}") + assert get_response.status_code == 404 + + def test_list_llm_models_with_pagination(self, client, sample_llm_model_data): + """Test listing LLM models with pagination""" + # Create multiple models + for i in range(3): + data = sample_llm_model_data.copy() + data["id"] = f"test-llm-{i}" + data["name"] = f"LLM Model {i}" + client.post("/api/llm", json=data) + + # Test pagination + response = client.get("/api/llm?page=1&limit=2") + assert response.status_code == 200 + data = response.json() + assert data["total"] == 3 + assert len(data["list"]) == 2 + + def test_filter_llm_models_by_type(self, client, sample_llm_model_data): + """Test filtering LLM models by type""" + # Create models with different types + for i, model_type in enumerate(["text", "embedding", "rerank"]): + data = sample_llm_model_data.copy() + data["id"] = f"test-llm-{model_type}" + data["name"] = f"LLM {model_type}" + data["type"] = model_type + client.post("/api/llm", json=data) + + # Filter by type + response = client.get("/api/llm?model_type=text") + assert response.status_code == 200 + data = response.json() + assert data["total"] >= 1 + for model in data["list"]: + assert model["type"] == "text" + + def test_filter_llm_models_by_enabled(self, client, sample_llm_model_data): + """Test filtering LLM models by enabled status""" + # Create enabled and disabled models + data = sample_llm_model_data.copy() + data["id"] = "test-llm-enabled" + data["name"] = "Enabled LLM" + data["enabled"] = True + client.post("/api/llm", json=data) + + data["id"] = "test-llm-disabled" + data["name"] = "Disabled LLM" + data["enabled"] = False + client.post("/api/llm", json=data) + + # Filter by enabled + response = client.get("/api/llm?enabled=true") + assert response.status_code == 200 + data = response.json() + for model in data["list"]: + assert model["enabled"] == True + + def test_create_llm_model_with_all_fields(self, client): + """Test creating an LLM model with all fields""" + data = { + "id": "full-llm", + "name": "Full LLM Model", + "vendor": "OpenAI", + "type": "text", + "base_url": "https://api.openai.com/v1", + "api_key": "sk-test", + "model_name": "gpt-4", + "temperature": 0.8, + "context_length": 16384, + "enabled": True + } + response = client.post("/api/llm", json=data) + assert response.status_code == 200 + result = response.json() + assert result["name"] == "Full LLM Model" + assert result["temperature"] == 0.8 + assert result["context_length"] == 16384 + + @patch('httpx.Client') + def test_test_llm_model_success(self, mock_client_class, client, sample_llm_model_data): + """Test testing an LLM model with successful connection""" + # Create model first + create_response = client.post("/api/llm", json=sample_llm_model_data) + model_id = create_response.json()["id"] + + # Mock the HTTP response + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [{"message": {"content": "OK"}}] + } + mock_response.raise_for_status = MagicMock() + mock_client.post.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + + with patch('app.routers.llm.httpx.Client', return_value=mock_client): + response = client.post(f"/api/llm/{model_id}/test") + assert response.status_code == 200 + data = response.json() + assert data["success"] == True + + @patch('httpx.Client') + def test_test_llm_model_failure(self, mock_client_class, client, sample_llm_model_data): + """Test testing an LLM model with failed connection""" + # Create model first + create_response = client.post("/api/llm", json=sample_llm_model_data) + model_id = create_response.json()["id"] + + # Mock HTTP error + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized")) + mock_client.post.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + + with patch('app.routers.llm.httpx.Client', return_value=mock_client): + response = client.post(f"/api/llm/{model_id}/test") + assert response.status_code == 200 + data = response.json() + assert data["success"] == False + + def test_different_llm_vendors(self, client): + """Test creating LLM models with different vendors""" + vendors = ["OpenAI", "SiliconFlow", "ZhipuAI", "Anthropic"] + for vendor in vendors: + data = { + "id": f"test-{vendor.lower()}", + "name": f"Test {vendor}", + "vendor": vendor, + "type": "text", + "base_url": f"https://api.{vendor.lower()}.com/v1", + "api_key": "test-key" + } + response = client.post("/api/llm", json=data) + assert response.status_code == 200 + assert response.json()["vendor"] == vendor + + def test_embedding_llm_model(self, client): + """Test creating an embedding LLM model""" + data = { + "id": "embedding-test", + "name": "Embedding Model", + "vendor": "OpenAI", + "type": "embedding", + "base_url": "https://api.openai.com/v1", + "api_key": "test-key", + "model_name": "text-embedding-3-small" + } + response = client.post("/api/llm", json=data) + assert response.status_code == 200 + assert response.json()["type"] == "embedding" diff --git a/api/tests/test_tools.py b/api/tests/test_tools.py new file mode 100644 index 0000000..9ce6cb4 --- /dev/null +++ b/api/tests/test_tools.py @@ -0,0 +1,267 @@ +"""Tests for Tools & Autotest API endpoints""" +import pytest +from unittest.mock import patch, MagicMock + + +class TestToolsAPI: + """Test cases for Tools endpoints""" + + def test_list_available_tools(self, client): + """Test listing all available tools""" + response = client.get("/api/tools/list") + assert response.status_code == 200 + data = response.json() + assert "tools" in data + # Check for expected tools + tools = data["tools"] + assert "search" in tools + assert "calculator" in tools + assert "weather" in tools + + def test_get_tool_detail(self, client): + """Test getting a specific tool's details""" + response = client.get("/api/tools/list/search") + assert response.status_code == 200 + data = response.json() + assert data["name"] == "网络搜索" + assert "parameters" in data + + def test_get_tool_detail_not_found(self, client): + """Test getting a non-existent tool""" + response = client.get("/api/tools/list/non-existent-tool") + assert response.status_code == 404 + + def test_health_check(self, client): + """Test health check endpoint""" + response = client.get("/api/tools/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "timestamp" in data + assert "tools" in data + + +class TestAutotestAPI: + """Test cases for Autotest endpoints""" + + def test_autotest_no_models(self, client): + """Test autotest without specifying model IDs""" + response = client.post("/api/tools/autotest") + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert "tests" in data + assert "summary" in data + # Should have test failures since no models provided + assert data["summary"]["total"] > 0 + + def test_autotest_with_llm_model(self, client, sample_llm_model_data): + """Test autotest with an LLM model""" + # Create an LLM model first + create_response = client.post("/api/llm", json=sample_llm_model_data) + model_id = create_response.json()["id"] + + # Run autotest + response = client.post(f"/api/tools/autotest?llm_model_id={model_id}&test_asr=false") + assert response.status_code == 200 + data = response.json() + assert "tests" in data + assert "summary" in data + + def test_autotest_with_asr_model(self, client, sample_asr_model_data): + """Test autotest with an ASR model""" + # Create an ASR model first + create_response = client.post("/api/asr", json=sample_asr_model_data) + model_id = create_response.json()["id"] + + # Run autotest + response = client.post(f"/api/tools/autotest?asr_model_id={model_id}&test_llm=false") + assert response.status_code == 200 + data = response.json() + assert "tests" in data + assert "summary" in data + + def test_autotest_with_both_models(self, client, sample_llm_model_data, sample_asr_model_data): + """Test autotest with both LLM and ASR models""" + # Create models + llm_response = client.post("/api/llm", json=sample_llm_model_data) + llm_id = llm_response.json()["id"] + + asr_response = client.post("/api/asr", json=sample_asr_model_data) + asr_id = asr_response.json()["id"] + + # Run autotest + response = client.post( + f"/api/tools/autotest?llm_model_id={llm_id}&asr_model_id={asr_id}" + ) + assert response.status_code == 200 + data = response.json() + assert "tests" in data + assert "summary" in data + + @patch('httpx.Client') + def test_autotest_llm_model_success(self, mock_client_class, client, sample_llm_model_data): + """Test autotest for a specific LLM model with successful connection""" + # Create an LLM model first + create_response = client.post("/api/llm", json=sample_llm_model_data) + model_id = create_response.json()["id"] + + # Mock the HTTP response for successful connection + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [{"message": {"content": "OK"}}] + } + mock_response.raise_for_status = MagicMock() + mock_response.iter_bytes = MagicMock(return_value=[b'chunk1', b'chunk2']) + mock_client.post.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + + with patch('app.routers.tools.httpx.Client', return_value=mock_client): + response = client.post(f"/api/tools/autotest/llm/{model_id}") + assert response.status_code == 200 + data = response.json() + assert "tests" in data + assert "summary" in data + + @patch('httpx.Client') + def test_autotest_asr_model_success(self, mock_client_class, client, sample_asr_model_data): + """Test autotest for a specific ASR model with successful connection""" + # Create an ASR model first + create_response = client.post("/api/asr", json=sample_asr_model_data) + model_id = create_response.json()["id"] + + # Mock the HTTP response for successful connection + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + + with patch('app.routers.tools.httpx.Client', return_value=mock_client): + response = client.post(f"/api/tools/autotest/asr/{model_id}") + assert response.status_code == 200 + data = response.json() + assert "tests" in data + assert "summary" in data + + def test_autotest_llm_model_not_found(self, client): + """Test autotest for a non-existent LLM model""" + response = client.post("/api/tools/autotest/llm/non-existent-id") + assert response.status_code == 200 + data = response.json() + # Should have a failure test + assert any(not t["passed"] for t in data["tests"]) + + def test_autotest_asr_model_not_found(self, client): + """Test autotest for a non-existent ASR model""" + response = client.post("/api/tools/autotest/asr/non-existent-id") + assert response.status_code == 200 + data = response.json() + # Should have a failure test + assert any(not t["passed"] for t in data["tests"]) + + @patch('httpx.Client') + def test_test_message_success(self, mock_client_class, client, sample_llm_model_data): + """Test sending a test message to an LLM model""" + # Create an LLM model first + create_response = client.post("/api/llm", json=sample_llm_model_data) + model_id = create_response.json()["id"] + + # Mock the HTTP response + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [{"message": {"content": "Hello! This is a test reply."}}], + "usage": {"total_tokens": 10} + } + mock_response.raise_for_status = MagicMock() + mock_client.post.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + + with patch('app.routers.tools.httpx.Client', return_value=mock_client): + response = client.post( + f"/api/tools/test-message?llm_model_id={model_id}", + json={"message": "Hello!"} + ) + assert response.status_code == 200 + data = response.json() + assert data["success"] == True + assert "reply" in data + + def test_test_message_model_not_found(self, client): + """Test sending a test message to a non-existent model""" + response = client.post( + "/api/tools/test-message?llm_model_id=non-existent", + json={"message": "Hello!"} + ) + assert response.status_code == 404 + + def test_autotest_result_structure(self, client): + """Test that autotest results have the correct structure""" + response = client.post("/api/tools/autotest") + assert response.status_code == 200 + data = response.json() + + # Check required fields + assert "id" in data + assert "started_at" in data + assert "duration_ms" in data + assert "tests" in data + assert "summary" in data + + # Check summary structure + assert "passed" in data["summary"] + assert "failed" in data["summary"] + assert "total" in data["summary"] + + # Check test structure + if data["tests"]: + test = data["tests"][0] + assert "name" in test + assert "passed" in test + assert "message" in test + assert "duration_ms" in test + + def test_tools_have_required_fields(self, client): + """Test that all tools have required fields""" + response = client.get("/api/tools/list") + assert response.status_code == 200 + data = response.json() + + for tool_id, tool in data["tools"].items(): + assert "name" in tool + assert "description" in tool + assert "parameters" in tool + + # Check parameters structure + params = tool["parameters"] + assert "type" in params + assert "properties" in params + + def test_calculator_tool_parameters(self, client): + """Test calculator tool has correct parameters""" + response = client.get("/api/tools/list/calculator") + assert response.status_code == 200 + data = response.json() + + assert data["name"] == "计算器" + assert "expression" in data["parameters"]["properties"] + assert "required" in data["parameters"] + assert "expression" in data["parameters"]["required"] + + def test_translate_tool_parameters(self, client): + """Test translate tool has correct parameters""" + response = client.get("/api/tools/list/translate") + assert response.status_code == 200 + data = response.json() + + assert data["name"] == "翻译" + assert "text" in data["parameters"]["properties"] + assert "target_lang" in data["parameters"]["properties"]