Compare commits

...

3 Commits

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| Xin Wang | 3d8635670f | Merge branch 'master' of https://gitea.xiaowang.eu.org/wx44wx/AI-VideoAssistant | 2026-02-08 15:54:01 +08:00 |
| Xin Wang | 7012f8edaf | Update backend api | 2026-02-08 15:52:16 +08:00 |
| Xin Wang | 727fe8a997 | Update backend schema | 2026-02-08 14:26:19 +08:00 |

28 changed files with 4955 additions and 100 deletions

View File

@@ -1,4 +1,4 @@
-FROM python:3.11-slim
+FROM python:3.12-slim
 WORKDIR /app
@@ -12,6 +12,6 @@ COPY . .
 # create the data directory
 RUN mkdir -p /app/data
-EXPOSE 8000
+EXPOSE 8100
-CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8100", "--reload"]

View File

@@ -1,13 +1,13 @@
 # AI VideoAssistant Backend
-Python backend API, used together with the `ai-videoassistant-frontend` frontend.
+Python backend API, used together with the `web/` frontend module.
 ## Quick Start
 ### 1. Install dependencies
 ```bash
-cd ~/Code/ai-videoassistant-backend
+cd api
 pip install -r requirements.txt
 ```
@@ -25,44 +25,162 @@ python init_db.py
 ```bash
 # development mode (hot reload)
-python -m uvicorn main:app --reload --host 0.0.0.0 --port 8000
+python -m uvicorn main:app --reload --host 0.0.0.0 --port 8100
 ```
+The service runs at: http://localhost:8100
 ### 4. Test the API
 ```bash
 # health check
-curl http://localhost:8000/health
+curl http://localhost:8100/health
 # list assistants
-curl http://localhost:8000/api/assistants
+curl http://localhost:8100/api/assistants
 # list voices
-curl http://localhost:8000/api/voices
+curl http://localhost:8100/api/voices
 # call history
-curl http://localhost:8000/api/history
+curl http://localhost:8100/api/history
 ```
+---
 ## API Documentation
-| Endpoint | Method | Description |
+The full API documentation lives in the [docs/](docs/) directory:
+| Module | Endpoint | Method | Description |
+|------|------|------|------|
+| **Assistant** | `/api/assistants` | GET | List assistants |
+| | | POST | Create assistant |
+| | `/api/assistants/{id}` | GET | Assistant details |
+| | | PUT | Update assistant |
+| | | DELETE | Delete assistant |
+| **Voice** | `/api/voices` | GET | List the voice library |
+| | | POST | Add a voice |
+| | `/api/voices/{id}` | GET | Voice details |
+| | | PUT | Update voice |
+| | | DELETE | Delete voice |
+| | `/api/voices/{id}/preview` | POST | Preview a voice |
+| **LLM Models** | `/api/models/llm` | GET | List LLM models |
+| | | POST | Add a model |
+| | `/api/models/llm/{id}` | GET | Model details |
+| | | PUT | Update model |
+| | | DELETE | Delete model |
+| | `/api/models/llm/{id}/test` | POST | Test model connectivity |
+| **ASR Models** | `/api/models/asr` | GET | List ASR models |
+| | | POST | Add a model |
+| | `/api/models/asr/{id}` | GET | Model details |
+| | | PUT | Update model |
+| | | DELETE | Delete model |
+| | `/api/models/asr/{id}/test` | POST | Test recognition |
+| **History** | `/api/history` | GET | List call history |
+| | `/api/history/{id}` | GET | Call details |
+| | | PUT | Update call record |
+| | | DELETE | Delete record |
+| | `/api/history/{id}/transcripts` | POST | Add a transcript |
+| | `/api/history/search` | GET | Search history |
+| | `/api/history/stats` | GET | Statistics |
+| **Knowledge** | `/api/knowledge/bases` | GET | List knowledge bases |
+| | | POST | Create knowledge base |
+| | `/api/knowledge/bases/{id}` | GET | Knowledge base details |
+| | | PUT | Update knowledge base |
+| | | DELETE | Delete knowledge base |
+| | `/api/knowledge/bases/{kb_id}/documents` | POST | Upload a document |
+| | `/api/knowledge/bases/{kb_id}/documents/{doc_id}` | DELETE | Delete a document |
+| | `/api/knowledge/bases/{kb_id}/documents/{doc_id}/index` | POST | Index a document |
+| | `/api/knowledge/search` | POST | Knowledge search |
+| **Workflow** | `/api/workflows` | GET | List workflows |
+| | | POST | Create workflow |
+| | `/api/workflows/{id}` | GET | Workflow details |
+| | | PUT | Update workflow |
+| | | DELETE | Delete workflow |
+---
+## Data Models
+### Assistant
-|------|------|------|
-| `/api/assistants` | GET | List assistants |
-| `/api/assistants` | POST | Create assistant |
-| `/api/assistants/{id}` | GET | Assistant details |
-| `/api/assistants/{id}` | PUT | Update assistant |
-| `/api/assistants/{id}` | DELETE | Delete assistant |
-| `/api/voices` | GET | List the voice library |
-| `/api/history` | GET | List call history |
-| `/api/history/{id}` | GET | Call details |
-| `/api/history/{id}/transcripts` | POST | Add a transcript |
-| `/api/history/{id}/audio/{turn}` | GET | Fetch audio |
+| Field | Type | Description |
+|------|------|------|
+| id | string | Assistant ID |
+| name | string | Assistant name |
+| opener | string | Opening line |
+| prompt | string | System prompt |
+| knowledgeBaseId | string | Linked knowledge base ID |
+| language | string | Language: zh/en |
+| voice | string | Voice ID |
+| speed | float | Speech rate (0.5-2.0) |
+| hotwords | array | Hotword list |
+| tools | array | Enabled tool list |
+| llmModelId | string | LLM model ID |
+| asrModelId | string | ASR model ID |
+| embeddingModelId | string | Embedding model ID |
+| rerankModelId | string | Rerank model ID |
+### Voice (voice resource)
+| Field | Type | Description |
+|------|------|------|
+| id | string | Voice ID |
+| name | string | Voice name |
+| vendor | string | Vendor: Ali/Volcano/Minimax |
+| gender | string | Gender: Male/Female |
+| language | string | Language: zh/en |
+| model | string | Vendor model identifier |
+| voice_key | string | Vendor voice_key |
+| speed | float | Speech rate |
+| gain | int | Gain (dB) |
+| pitch | int | Pitch |
+### LLMModel (model integration)
+| Field | Type | Description |
+|------|------|------|
+| id | string | Model ID |
+| name | string | Model name |
+| vendor | string | Vendor |
+| type | string | Type: text/embedding/rerank |
+| base_url | string | API base URL |
+| api_key | string | API key |
+| model_name | string | Model name |
+| temperature | float | Temperature parameter |
+### ASRModel (speech recognition)
+| Field | Type | Description |
+|------|------|------|
+| id | string | Model ID |
+| name | string | Model name |
+| vendor | string | Vendor |
+| language | string | Language: zh/en/Multi-lingual |
+| base_url | string | API base URL |
+| api_key | string | API key |
+| hotwords | array | Hotword list |
+### CallRecord (call record)
+| Field | Type | Description |
+|------|------|------|
+| id | string | Record ID |
+| assistant_id | string | Assistant ID |
+| source | string | Source: debug/external |
+| status | string | Status: connected/missed/failed |
+| started_at | string | Start time |
+| duration_seconds | int | Call duration (seconds) |
+| summary | string | Call summary |
+| transcripts | array | Conversation transcripts |
+---
 ## Start with Docker
 ```bash
-cd ~/Code/ai-videoassistant-backend
+cd api
 # start all services
 docker-compose up -d
@@ -71,33 +189,144 @@ docker-compose up -d
 docker-compose logs -f backend
 ```
+---
 ## Directory structure
 ```
-backend/
+api/
 ├── app/
 │   ├── __init__.py
 │   ├── main.py            # FastAPI entry point
 │   ├── db.py              # SQLite connection
-│   ├── models.py          # data models
+│   ├── models.py          # SQLAlchemy data models
 │   ├── schemas.py         # Pydantic models
 │   ├── storage.py         # MinIO storage
+│   ├── vector_store.py    # vector store
 │   └── routers/
 │       ├── __init__.py
 │       ├── assistants.py  # assistant API
 │       ├── history.py     # call history API
+│       └── knowledge.py   # knowledge base API
 ├── data/                  # database files
+├── docs/                  # API documentation
 ├── requirements.txt
 ├── .env
+├── init_db.py
+├── main.py
 └── docker-compose.yml
 ```
+---
 ## Environment variables
 | Variable | Default | Description |
 |------|--------|------|
+| `PORT` | `8100` | Service port |
 | `DATABASE_URL` | `sqlite:///./data/app.db` | Database connection |
 | `MINIO_ENDPOINT` | `localhost:9000` | MinIO address |
 | `MINIO_ACCESS_KEY` | `admin` | MinIO access key |
 | `MINIO_SECRET_KEY` | `password123` | MinIO secret key |
 | `MINIO_BUCKET` | `ai-audio` | Bucket name |
+---
+## Database migration
+To recreate the database in development:
+```bash
+rm -f api/data/app.db
+python api/init_db.py
+```
+---
+## Testing
+### Install test dependencies
+```bash
+cd api
+pip install pytest pytest-cov -q
+```
+### Run all tests
+```bash
+# Windows
+run_tests.bat
+# or use pytest directly
+pytest tests/ -v
+```
+### Run specific tests
+```bash
+# voice API only
+pytest tests/test_voices.py -v
+# assistant API only
+pytest tests/test_assistants.py -v
+# history API only
+pytest tests/test_history.py -v
+# knowledge base API only
+pytest tests/test_knowledge.py -v
+```
+### Test coverage
+```bash
+pytest tests/ --cov=app --cov-report=html
+# view the report: open htmlcov/index.html
+```
+### Test directory structure
+```
+tests/
+├── __init__.py
+├── conftest.py          # pytest fixtures
+├── test_voices.py       # voice API tests
+├── test_assistants.py   # assistant API tests
+├── test_history.py      # history API tests
+└── test_knowledge.py    # knowledge base API tests
+```
+### Test case counts
+| Module | Test cases |
+|------|-----------|
+| Voice | 13 |
+| Assistant | 14 |
+| History | 18 |
+| Knowledge | 19 |
+| **Total** | **64** |
+### CI/CD example (.github/workflows/test.yml)
+```yaml
+name: Tests
+on: [push, pull_request]
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+      - name: Install dependencies
+        run: |
+          pip install -r api/requirements.txt
+          pip install pytest pytest-cov
+      - name: Run tests
+        run: pytest api/tests/ -v --cov=app
+```
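The environment variables documented above can be read with the standard library's `os.getenv`, falling back to the listed defaults. A minimal sketch (the `load_settings` helper is illustrative, not part of the codebase):

```python
import os

def load_settings() -> dict:
    """Read the backend settings from the environment, using the README defaults."""
    return {
        "port": int(os.getenv("PORT", "8100")),
        "database_url": os.getenv("DATABASE_URL", "sqlite:///./data/app.db"),
        "minio_endpoint": os.getenv("MINIO_ENDPOINT", "localhost:9000"),
        "minio_access_key": os.getenv("MINIO_ACCESS_KEY", "admin"),
        "minio_secret_key": os.getenv("MINIO_SECRET_KEY", "password123"),
        "minio_bucket": os.getenv("MINIO_BUCKET", "ai-audio"),
    }

settings = load_settings()
```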

View File

@@ -1,7 +1,10 @@
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker, DeclarativeBase
+import os
-DATABASE_URL = "sqlite:///./data/app.db"
+# use an absolute path
+BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+DATABASE_URL = f"sqlite:///{os.path.join(BASE_DIR, 'data', 'app.db')}"
 engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
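This change pins the SQLite file to an absolute path derived from the module's own location, so the database no longer depends on the process working directory. The path resolution can be sketched in isolation (the `resolve_database_url` helper is illustrative):

```python
import os

def resolve_database_url(module_file: str) -> str:
    """Mimic db.py: two dirname() calls climb from app/db.py to the
    project root, then point at data/app.db under that root."""
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(module_file)))
    return f"sqlite:///{os.path.join(base_dir, 'data', 'app.db')}"

url = resolve_database_url("/srv/api/app/db.py")
```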

View File

@@ -4,7 +4,7 @@ from contextlib import asynccontextmanager
 import os
 from .db import Base, engine
-from .routers import assistants, history
+from .routers import assistants, history, knowledge, llm, asr, tools
 @asynccontextmanager
@@ -33,6 +33,10 @@ app.add_middleware(
 # routes
 app.include_router(assistants.router, prefix="/api")
 app.include_router(history.router, prefix="/api")
+app.include_router(knowledge.router, prefix="/api")
+app.include_router(llm.router, prefix="/api")
+app.include_router(asr.router, prefix="/api")
+app.include_router(tools.router, prefix="/api")
 @app.get("/")
@@ -42,31 +46,4 @@ def root():
 @app.get("/health")
 def health():
     return {"status": "ok"}
-# initialize default data
-@app.on_event("startup")
-def init_default_data():
-    from sqlalchemy.orm import Session
-    from .db import SessionLocal
-    from .models import Voice
-    db = SessionLocal()
-    try:
-        # check whether data already exists
-        if db.query(Voice).count() == 0:
-            # insert default voices
-            voices = [
-                Voice(id="v1", name="Xiaoyun", vendor="Ali", gender="Female", language="zh", description="Gentle and professional."),
-                Voice(id="v2", name="Kevin", vendor="Volcano", gender="Male", language="en", description="Deep and authoritative."),
-                Voice(id="v3", name="Abby", vendor="Minimax", gender="Female", language="en", description="Cheerful and lively."),
-                Voice(id="v4", name="Guang", vendor="Ali", gender="Male", language="zh", description="Standard newscast style."),
-                Voice(id="v5", name="Doubao", vendor="Volcano", gender="Female", language="zh", description="Cute and young."),
-            ]
-            for v in voices:
-                db.add(v)
-            db.commit()
-            print("✅ 默认声音数据已初始化")
-    finally:
-        db.close()
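The `init_default_data` startup hook in this diff follows a seed-if-empty pattern: defaults are inserted only when the table has no rows, so restarts never duplicate data. The idea in isolation, with plain Python lists standing in for the database (`seed_if_empty` is illustrative, not from the codebase):

```python
DEFAULT_VOICES = [
    {"id": "v1", "name": "Xiaoyun", "vendor": "Ali"},
    {"id": "v2", "name": "Kevin", "vendor": "Volcano"},
]

def seed_if_empty(table: list, defaults: list) -> bool:
    """Insert defaults only when the table is empty; return True if seeding ran."""
    if len(table) == 0:
        table.extend(defaults)
        return True
    return False

voices: list = []
first = seed_if_empty(voices, DEFAULT_VOICES)   # seeds on first start
second = seed_if_empty(voices, DEFAULT_VOICES)  # no-op on restart
```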

View File

@@ -1,6 +1,6 @@
 from datetime import datetime
 from typing import List, Optional
-from sqlalchemy import String, Integer, DateTime, Text, Float, ForeignKey, JSON
+from sqlalchemy import String, Integer, DateTime, Text, Float, ForeignKey, JSON, Enum
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 from .db import Base
@@ -15,18 +15,72 @@ class User(Base):
     created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
+# ============ Voice ============
 class Voice(Base):
     __tablename__ = "voices"
     id: Mapped[str] = mapped_column(String(64), primary_key=True)
+    user_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("users.id"), index=True, nullable=True)
     name: Mapped[str] = mapped_column(String(128), nullable=False)
     vendor: Mapped[str] = mapped_column(String(64), nullable=False)
     gender: Mapped[str] = mapped_column(String(32), nullable=False)
     language: Mapped[str] = mapped_column(String(16), nullable=False)
     description: Mapped[str] = mapped_column(String(255), nullable=False)
-    voice_params: Mapped[dict] = mapped_column(JSON, default=dict)
+    model: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)  # vendor voice model identifier
+    voice_key: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)  # vendor voice_key
+    speed: Mapped[float] = mapped_column(Float, default=1.0)
+    gain: Mapped[int] = mapped_column(Integer, default=0)
+    pitch: Mapped[int] = mapped_column(Integer, default=0)
+    enabled: Mapped[bool] = mapped_column(default=True)
+    is_system: Mapped[bool] = mapped_column(default=False)
+    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
+    user = relationship("User", foreign_keys=[user_id])
+# ============ LLM Model ============
+class LLMModel(Base):
+    __tablename__ = "llm_models"
+    id: Mapped[str] = mapped_column(String(64), primary_key=True)
+    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True)
+    name: Mapped[str] = mapped_column(String(128), nullable=False)
+    vendor: Mapped[str] = mapped_column(String(64), nullable=False)
+    type: Mapped[str] = mapped_column(String(32), nullable=False)  # text/embedding/rerank
+    base_url: Mapped[str] = mapped_column(String(512), nullable=False)
+    api_key: Mapped[str] = mapped_column(String(512), nullable=False)
+    model_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
+    temperature: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+    context_length: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
+    enabled: Mapped[bool] = mapped_column(default=True)
+    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
+    updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
+    user = relationship("User")
+# ============ ASR Model ============
+class ASRModel(Base):
+    __tablename__ = "asr_models"
+    id: Mapped[str] = mapped_column(String(64), primary_key=True)
+    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True)
+    name: Mapped[str] = mapped_column(String(128), nullable=False)
+    vendor: Mapped[str] = mapped_column(String(64), nullable=False)
+    language: Mapped[str] = mapped_column(String(32), nullable=False)  # zh/en/Multi-lingual
+    base_url: Mapped[str] = mapped_column(String(512), nullable=False)
+    api_key: Mapped[str] = mapped_column(String(512), nullable=False)
+    model_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
+    hotwords: Mapped[dict] = mapped_column(JSON, default=list)
+    enable_punctuation: Mapped[bool] = mapped_column(default=True)
+    enable_normalization: Mapped[bool] = mapped_column(default=True)
+    enabled: Mapped[bool] = mapped_column(default=True)
+    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
+    user = relationship("User")
+# ============ Assistant ============
 class Assistant(Base):
     __tablename__ = "assistants"
@@ -46,6 +100,11 @@ class Assistant(Base):
     config_mode: Mapped[str] = mapped_column(String(32), default="platform")
     api_url: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
     api_key: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
+    # model associations
+    llm_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
+    asr_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
+    embedding_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
+    rerank_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
     created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
     updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
@@ -53,6 +112,7 @@ class Assistant(Base):
     call_records = relationship("CallRecord", back_populates="assistant")
+# ============ Knowledge Base ============
 class KnowledgeBase(Base):
     __tablename__ = "knowledge_bases"
@@ -92,6 +152,7 @@ class KnowledgeDocument(Base):
     kb = relationship("KnowledgeBase", back_populates="documents")
+# ============ Workflow ============
 class Workflow(Base):
     __tablename__ = "workflows"
@@ -108,6 +169,7 @@ class Workflow(Base):
     user = relationship("User")
+# ============ Call Record ============
 class CallRecord(Base):
     __tablename__ = "call_records"

View File

@@ -3,9 +3,15 @@ from fastapi import APIRouter
 from . import assistants
 from . import history
 from . import knowledge
+from . import llm
+from . import asr
+from . import tools
 router = APIRouter()
 router.include_router(assistants.router)
 router.include_router(history.router)
 router.include_router(knowledge.router)
+router.include_router(llm.router)
+router.include_router(asr.router)
+router.include_router(tools.router)

api/app/routers/asr.py (new file, 268 lines)
View File

@@ -0,0 +1,268 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
import httpx
import time
import base64
import json
from datetime import datetime
from ..db import get_db
from ..models import ASRModel
from ..schemas import (
    ASRModelCreate, ASRModelUpdate, ASRModelOut,
    ASRTestRequest, ASRTestResponse, ListResponse
)
router = APIRouter(prefix="/asr", tags=["ASR Models"])
# ============ ASR Models CRUD ============
@router.get("", response_model=ListResponse)
def list_asr_models(
    language: Optional[str] = None,
    enabled: Optional[bool] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """List ASR models"""
    query = db.query(ASRModel)
    if language:
        query = query.filter(ASRModel.language == language)
    if enabled is not None:
        query = query.filter(ASRModel.enabled == enabled)
    total = query.count()
    models = query.order_by(ASRModel.created_at.desc()) \
        .offset((page-1)*limit).limit(limit).all()
    return {"total": total, "page": page, "limit": limit, "list": models}
@router.get("/{id}", response_model=ASRModelOut)
def get_asr_model(id: str, db: Session = Depends(get_db)):
    """Get a single ASR model's details"""
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    return model
@router.post("", response_model=ASRModelOut)
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
    """Create an ASR model"""
    asr_model = ASRModel(
        id=data.id or str(uuid.uuid4())[:8],
        user_id=1,  # default user
        name=data.name,
        vendor=data.vendor,
        language=data.language,
        base_url=data.base_url,
        api_key=data.api_key,
        model_name=data.model_name,
        hotwords=data.hotwords,
        enable_punctuation=data.enable_punctuation,
        enable_normalization=data.enable_normalization,
        enabled=data.enabled,
    )
    db.add(asr_model)
    db.commit()
    db.refresh(asr_model)
    return asr_model
@router.put("/{id}", response_model=ASRModelOut)
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
    """Update an ASR model"""
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    update_data = data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(model, field, value)
    db.commit()
    db.refresh(model)
    return model
@router.delete("/{id}")
def delete_asr_model(id: str, db: Session = Depends(get_db)):
    """Delete an ASR model"""
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    db.delete(model)
    db.commit()
    return {"message": "Deleted successfully"}
@router.post("/{id}/test", response_model=ASRTestResponse)
def test_asr_model(
    id: str,
    request: Optional[ASRTestRequest] = None,
    db: Session = Depends(get_db)
):
    """Test an ASR model"""
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    start_time = time.time()
    try:
        # build a vendor-specific request
        if model.vendor.lower() in ["siliconflow", "paraformer"]:
            # SiliconFlow/Paraformer format
            payload = {
                "model": model.model_name or "paraformer-v2",
                "input": {},
                "parameters": {
                    "hotwords": " ".join(model.hotwords) if model.hotwords else "",
                    "enable_punctuation": model.enable_punctuation,
                    "enable_normalization": model.enable_normalization,
                }
            }
            # if audio data is provided
            if request and request.audio_data:
                payload["input"]["file_urls"] = []
            elif request and request.audio_url:
                payload["input"]["url"] = request.audio_url
            headers = {"Authorization": f"Bearer {model.api_key}"}
            with httpx.Client(timeout=60.0) as client:
                response = client.post(
                    f"{model.base_url}/asr",
                    json=payload,
                    headers=headers
                )
                response.raise_for_status()
                result = response.json()
        elif model.vendor.lower() == "openai":
            # OpenAI Whisper format
            headers = {"Authorization": f"Bearer {model.api_key}"}
            # prepare the file
            files = {}
            if request and request.audio_data:
                audio_bytes = base64.b64decode(request.audio_data)
                files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
                data = {"model": model.model_name or "whisper-1"}
            elif request and request.audio_url:
                files = {"file": ("audio.wav", httpx.get(request.audio_url).content, "audio/wav")}
                data = {"model": model.model_name or "whisper-1"}
            else:
                return ASRTestResponse(
                    success=False,
                    error="No audio data or URL provided"
                )
            with httpx.Client(timeout=60.0) as client:
                response = client.post(
                    f"{model.base_url}/audio/transcriptions",
                    files=files,
                    data=data,
                    headers=headers
                )
                response.raise_for_status()
                result = response.json()
                result = {"results": [{"transcript": result.get("text", "")}]}
        else:
            # generic format (extend as needed)
            return ASRTestResponse(
                success=False,
                message=f"Unsupported vendor: {model.vendor}"
            )
        latency_ms = int((time.time() - start_time) * 1000)
        # parse the result
        if result_data := result.get("results", [{}])[0]:
            transcript = result_data.get("transcript", "")
            return ASRTestResponse(
                success=True,
                transcript=transcript,
                language=result_data.get("language", model.language),
                confidence=result_data.get("confidence"),
                latency_ms=latency_ms,
            )
        return ASRTestResponse(
            success=False,
            message="No transcript in response",
            latency_ms=latency_ms
        )
    except httpx.HTTPStatusError as e:
        return ASRTestResponse(
            success=False,
            error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
        )
    except Exception as e:
        return ASRTestResponse(
            success=False,
            error=str(e)[:200]
        )
@router.post("/{id}/transcribe")
def transcribe_audio(
    id: str,
    audio_url: Optional[str] = None,
    audio_data: Optional[str] = None,
    hotwords: Optional[List[str]] = None,
    db: Session = Depends(get_db)
):
    """Transcribe audio"""
    model = db.query(ASRModel).filter(ASRModel.id == id).first()
    if not model:
        raise HTTPException(status_code=404, detail="ASR Model not found")
    try:
        payload = {
            "model": model.model_name or "paraformer-v2",
            "input": {},
            "parameters": {
                "hotwords": " ".join(hotwords or model.hotwords or []),
                "enable_punctuation": model.enable_punctuation,
                "enable_normalization": model.enable_normalization,
            }
        }
        headers = {"Authorization": f"Bearer {model.api_key}"}
        if audio_url:
            payload["input"]["url"] = audio_url
        elif audio_data:
            payload["input"]["file_urls"] = []
        with httpx.Client(timeout=120.0) as client:
            response = client.post(
                f"{model.base_url}/asr",
                json=payload,
                headers=headers
            )
            response.raise_for_status()
            result = response.json()
        if result_data := result.get("results", [{}])[0]:
            return {
                "success": True,
                "transcript": result_data.get("transcript", ""),
                "language": result_data.get("language", model.language),
                "confidence": result_data.get("confidence"),
            }
        return {"success": False, "error": "No transcript in response"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
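The list endpoints in asr.py (and the other routers in this change) share one offset-pagination pattern: skip `(page-1)*limit` rows, take `limit`, and return the total alongside the page. The arithmetic in isolation, with a plain sequence standing in for the SQLAlchemy offset/limit chain (`paginate` is illustrative, not from the codebase):

```python
from typing import Sequence

def paginate(items: Sequence, page: int = 1, limit: int = 50) -> dict:
    """Return one page of items in the same envelope the routers use."""
    start = (page - 1) * limit  # rows to skip
    return {
        "total": len(items),
        "page": page,
        "limit": limit,
        "list": list(items[start:start + limit]),
    }

page2 = paginate(list(range(120)), page=2, limit=50)
```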

View File

@@ -1,6 +1,6 @@
 from fastapi import APIRouter, Depends, HTTPException
 from sqlalchemy.orm import Session
-from typing import List
+from typing import List, Optional
 import uuid
 from datetime import datetime
@@ -8,7 +8,7 @@ from ..db import get_db
 from ..models import Assistant, Voice, Workflow
 from ..schemas import (
     AssistantCreate, AssistantUpdate, AssistantOut,
-    VoiceOut,
+    VoiceCreate, VoiceUpdate, VoiceOut,
     WorkflowCreate, WorkflowUpdate, WorkflowOut
 )
@@ -16,11 +16,88 @@ router = APIRouter()
 # ============ Voices ============
-@router.get("/voices", response_model=List[VoiceOut])
-def list_voices(db: Session = Depends(get_db)):
-    """List the voice library"""
-    voices = db.query(Voice).all()
-    return voices
+@router.get("/voices")
+def list_voices(
+    vendor: Optional[str] = None,
+    language: Optional[str] = None,
+    gender: Optional[str] = None,
+    page: int = 1,
+    limit: int = 50,
+    db: Session = Depends(get_db)
+):
+    """List the voice library"""
+    query = db.query(Voice)
+    if vendor:
+        query = query.filter(Voice.vendor == vendor)
+    if language:
+        query = query.filter(Voice.language == language)
+    if gender:
+        query = query.filter(Voice.gender == gender)
+    total = query.count()
+    voices = query.order_by(Voice.created_at.desc()) \
+        .offset((page-1)*limit).limit(limit).all()
+    return {"total": total, "page": page, "limit": limit, "list": voices}
+@router.post("/voices", response_model=VoiceOut)
+def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
+    """Create a voice"""
+    voice = Voice(
+        id=data.id or str(uuid.uuid4())[:8],
+        user_id=1,
+        name=data.name,
+        vendor=data.vendor,
+        gender=data.gender,
+        language=data.language,
+        description=data.description,
+        model=data.model,
+        voice_key=data.voice_key,
+        speed=data.speed,
+        gain=data.gain,
+        pitch=data.pitch,
+        enabled=data.enabled,
+    )
+    db.add(voice)
+    db.commit()
+    db.refresh(voice)
+    return voice
+@router.get("/voices/{id}", response_model=VoiceOut)
+def get_voice(id: str, db: Session = Depends(get_db)):
+    """Get a single voice's details"""
+    voice = db.query(Voice).filter(Voice.id == id).first()
+    if not voice:
+        raise HTTPException(status_code=404, detail="Voice not found")
+    return voice
+@router.put("/voices/{id}", response_model=VoiceOut)
+def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
+    """Update a voice"""
+    voice = db.query(Voice).filter(Voice.id == id).first()
+    if not voice:
+        raise HTTPException(status_code=404, detail="Voice not found")
+    update_data = data.model_dump(exclude_unset=True)
+    for field, value in update_data.items():
+        setattr(voice, field, value)
+    db.commit()
+    db.refresh(voice)
+    return voice
+@router.delete("/voices/{id}")
+def delete_voice(id: str, db: Session = Depends(get_db)):
+    """Delete a voice"""
+    voice = db.query(Voice).filter(Voice.id == id).first()
+    if not voice:
+        raise HTTPException(status_code=404, detail="Voice not found")
+    db.delete(voice)
+    db.commit()
+    return {"message": "Deleted successfully"}
 # ============ Assistants ============
@@ -79,11 +156,11 @@ def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_d
     assistant = db.query(Assistant).filter(Assistant.id == id).first()
     if not assistant:
         raise HTTPException(status_code=404, detail="Assistant not found")
     update_data = data.model_dump(exclude_unset=True)
     for field, value in update_data.items():
         setattr(assistant, field, value)
     assistant.updated_at = datetime.utcnow()
     db.commit()
     db.refresh(assistant)
@@ -103,10 +180,17 @@ def delete_assistant(id: str, db: Session = Depends(get_db)):
 # ============ Workflows ============
 @router.get("/workflows", response_model=List[WorkflowOut])
-def list_workflows(db: Session = Depends(get_db)):
-    """List workflows"""
-    workflows = db.query(Workflow).all()
-    return workflows
+def list_workflows(
+    page: int = 1,
+    limit: int = 50,
+    db: Session = Depends(get_db)
+):
+    """List workflows"""
+    query = db.query(Workflow)
+    total = query.count()
+    workflows = query.order_by(Workflow.created_at.desc()) \
+        .offset((page-1)*limit).limit(limit).all()
+    return {"total": total, "page": page, "limit": limit, "list": workflows}
 @router.post("/workflows", response_model=WorkflowOut)
@@ -129,17 +213,26 @@ def create_workflow(data: WorkflowCreate, db: Session = Depends(get_db)):
     return workflow
+@router.get("/workflows/{id}", response_model=WorkflowOut)
+def get_workflow(id: str, db: Session = Depends(get_db)):
+    """Get a single workflow"""
+    workflow = db.query(Workflow).filter(Workflow.id == id).first()
+    if not workflow:
+        raise HTTPException(status_code=404, detail="Workflow not found")
+    return workflow
 @router.put("/workflows/{id}", response_model=WorkflowOut)
 def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)):
     """Update a workflow"""
     workflow = db.query(Workflow).filter(Workflow.id == id).first()
     if not workflow:
         raise HTTPException(status_code=404, detail="Workflow not found")
     update_data = data.model_dump(exclude_unset=True)
     for field, value in update_data.items():
         setattr(workflow, field, value)
     workflow.updated_at = datetime.utcnow().isoformat()
     db.commit()
     db.refresh(workflow)
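The update endpoints above share one partial-update pattern: `data.model_dump(exclude_unset=True)` yields only the fields the client actually sent, so unset fields are never overwritten. The effect can be sketched without Pydantic, using dicts in place of ORM objects (`apply_partial_update` is illustrative):

```python
def apply_partial_update(obj: dict, patch: dict) -> dict:
    """Copy only the keys present in the patch, mirroring the setattr loop
    over model_dump(exclude_unset=True) in the routers."""
    for field, value in patch.items():
        obj[field] = value
    return obj

workflow = {"id": "w1", "name": "Old", "enabled": True}
apply_partial_update(workflow, {"name": "New"})  # 'enabled' is left untouched
```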

api/app/routers/llm.py (new file, 206 lines)
View File

@@ -0,0 +1,206 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
import httpx
import time
from datetime import datetime
from ..db import get_db
from ..models import LLMModel
from ..schemas import (
LLMModelCreate, LLMModelUpdate, LLMModelOut,
LLMModelTestResponse, ListResponse
)
router = APIRouter(prefix="/llm", tags=["LLM Models"])
# ============ LLM Models CRUD ============
@router.get("", response_model=ListResponse)
def list_llm_models(
model_type: Optional[str] = None,
enabled: Optional[bool] = None,
page: int = 1,
limit: int = 50,
db: Session = Depends(get_db)
):
    """List LLM models."""
query = db.query(LLMModel)
if model_type:
query = query.filter(LLMModel.type == model_type)
if enabled is not None:
query = query.filter(LLMModel.enabled == enabled)
total = query.count()
models = query.order_by(LLMModel.created_at.desc()) \
.offset((page-1)*limit).limit(limit).all()
return {"total": total, "page": page, "limit": limit, "list": models}
@router.get("/{id}", response_model=LLMModelOut)
def get_llm_model(id: str, db: Session = Depends(get_db)):
    """Get a single LLM model."""
model = db.query(LLMModel).filter(LLMModel.id == id).first()
if not model:
raise HTTPException(status_code=404, detail="LLM Model not found")
return model
@router.post("", response_model=LLMModelOut)
def create_llm_model(data: LLMModelCreate, db: Session = Depends(get_db)):
    """Create an LLM model."""
    llm_model = LLMModel(
        id=data.id or str(uuid.uuid4())[:8],
        user_id=1,  # default user
name=data.name,
vendor=data.vendor,
type=data.type.value if hasattr(data.type, 'value') else data.type,
base_url=data.base_url,
api_key=data.api_key,
model_name=data.model_name,
temperature=data.temperature,
context_length=data.context_length,
enabled=data.enabled,
)
db.add(llm_model)
db.commit()
db.refresh(llm_model)
return llm_model
@router.put("/{id}", response_model=LLMModelOut)
def update_llm_model(id: str, data: LLMModelUpdate, db: Session = Depends(get_db)):
    """Update an LLM model."""
model = db.query(LLMModel).filter(LLMModel.id == id).first()
if not model:
raise HTTPException(status_code=404, detail="LLM Model not found")
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(model, field, value)
model.updated_at = datetime.utcnow()
db.commit()
db.refresh(model)
return model
@router.delete("/{id}")
def delete_llm_model(id: str, db: Session = Depends(get_db)):
    """Delete an LLM model."""
model = db.query(LLMModel).filter(LLMModel.id == id).first()
if not model:
raise HTTPException(status_code=404, detail="LLM Model not found")
db.delete(model)
db.commit()
return {"message": "Deleted successfully"}
@router.post("/{id}/test", response_model=LLMModelTestResponse)
def test_llm_model(id: str, db: Session = Depends(get_db)):
    """Test connectivity of an LLM model."""
model = db.query(LLMModel).filter(LLMModel.id == id).first()
if not model:
raise HTTPException(status_code=404, detail="LLM Model not found")
start_time = time.time()
try:
        # Build a minimal test request
test_messages = [{"role": "user", "content": "Hello, please reply with 'OK'."}]
payload = {
"model": model.model_name or "gpt-3.5-turbo",
"messages": test_messages,
"max_tokens": 10,
"temperature": 0.1,
}
headers = {"Authorization": f"Bearer {model.api_key}"}
with httpx.Client(timeout=30.0) as client:
response = client.post(
f"{model.base_url}/chat/completions",
json=payload,
headers=headers
)
response.raise_for_status()
latency_ms = int((time.time() - start_time) * 1000)
result = response.json()
if result.get("choices"):
return LLMModelTestResponse(
success=True,
latency_ms=latency_ms,
message="Connection successful"
)
else:
return LLMModelTestResponse(
success=False,
latency_ms=latency_ms,
message="Unexpected response format"
)
except httpx.HTTPStatusError as e:
return LLMModelTestResponse(
success=False,
message=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
)
except Exception as e:
return LLMModelTestResponse(
success=False,
message=str(e)[:200]
)
@router.post("/{id}/chat")
def chat_with_llm(
id: str,
message: str,
system_prompt: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
db: Session = Depends(get_db)
):
    """Chat with an LLM model."""
model = db.query(LLMModel).filter(LLMModel.id == id).first()
if not model:
raise HTTPException(status_code=404, detail="LLM Model not found")
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": message})
payload = {
"model": model.model_name or "gpt-3.5-turbo",
"messages": messages,
"max_tokens": max_tokens or 1000,
"temperature": temperature if temperature is not None else model.temperature or 0.7,
}
headers = {"Authorization": f"Bearer {model.api_key}"}
try:
with httpx.Client(timeout=60.0) as client:
response = client.post(
f"{model.base_url}/chat/completions",
json=payload,
headers=headers
)
response.raise_for_status()
            result = response.json()
            choices = result.get("choices") or []
            if choices:
                return {
                    "success": True,
                    "reply": choices[0].get("message", {}).get("content", ""),
                    "usage": result.get("usage", {}),
                }
        return {"success": False, "reply": "", "error": "No response"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

api/app/routers/tools.py Normal file (+379 lines)

@@ -0,0 +1,379 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Optional, Dict, Any
import time
import uuid
import httpx
from ..db import get_db
from ..models import LLMModel, ASRModel
router = APIRouter(prefix="/tools", tags=["Tools & Autotest"])
# ============ Available Tools ============
TOOL_REGISTRY = {
"search": {
"name": "网络搜索",
"description": "搜索互联网获取最新信息",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "搜索关键词"}
},
"required": ["query"]
}
},
"calculator": {
"name": "计算器",
"description": "执行数学计算",
"parameters": {
"type": "object",
"properties": {
"expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}
},
"required": ["expression"]
}
},
"weather": {
"name": "天气查询",
"description": "查询指定城市的天气",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "城市名称"}
},
"required": ["city"]
}
},
"translate": {
"name": "翻译",
"description": "翻译文本到指定语言",
"parameters": {
"type": "object",
"properties": {
"text": {"type": "string", "description": "要翻译的文本"},
"target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"}
},
"required": ["text", "target_lang"]
}
},
"knowledge": {
"name": "知识库查询",
"description": "从知识库中检索相关信息",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "查询内容"},
"kb_id": {"type": "string", "description": "知识库ID"}
},
"required": ["query"]
}
},
"code_interpreter": {
"name": "代码执行",
"description": "安全地执行Python代码",
"parameters": {
"type": "object",
"properties": {
"code": {"type": "string", "description": "要执行的Python代码"}
},
"required": ["code"]
}
},
}
@router.get("/list")
def list_available_tools():
    """List the available tools."""
    return {"tools": TOOL_REGISTRY}

@router.get("/list/{tool_id}")
def get_tool_detail(tool_id: str):
    """Get details for a single tool."""
    if tool_id not in TOOL_REGISTRY:
        raise HTTPException(status_code=404, detail="Tool not found")
    return TOOL_REGISTRY[tool_id]
# ============ Autotest ============
class AutotestResult:
    """Accumulated results of an autotest run."""
def __init__(self):
self.id = str(uuid.uuid4())[:8]
self.started_at = time.time()
self.tests = []
self.summary = {"passed": 0, "failed": 0, "total": 0}
def add_test(self, name: str, passed: bool, message: str = "", duration_ms: int = 0):
self.tests.append({
"name": name,
"passed": passed,
"message": message,
"duration_ms": duration_ms
})
if passed:
self.summary["passed"] += 1
else:
self.summary["failed"] += 1
self.summary["total"] += 1
def to_dict(self):
return {
"id": self.id,
"started_at": self.started_at,
"duration_ms": int((time.time() - self.started_at) * 1000),
"tests": self.tests,
"summary": self.summary
}
@router.post("/autotest")
def run_autotest(
llm_model_id: Optional[str] = None,
asr_model_id: Optional[str] = None,
test_llm: bool = True,
test_asr: bool = True,
db: Session = Depends(get_db)
):
    """Run the autotest suite."""
    result = AutotestResult()
    # Test the LLM model
    if test_llm and llm_model_id:
        _test_llm_model(db, llm_model_id, result)
    # Test the ASR model
    if test_asr and asr_model_id:
        _test_asr_model(db, asr_model_id, result)
    # Flag requested tests that are missing a model ID
    if test_llm and not llm_model_id:
        result.add_test(
            "LLM Model Check",
            False,
            "No LLM model ID provided"
        )
    if test_asr and not asr_model_id:
        result.add_test(
            "ASR Model Check",
            False,
            "No ASR model ID provided"
        )
return result.to_dict()
@router.post("/autotest/llm/{model_id}")
def autotest_llm_model(model_id: str, db: Session = Depends(get_db)):
    """Test a single LLM model."""
result = AutotestResult()
_test_llm_model(db, model_id, result)
return result.to_dict()
@router.post("/autotest/asr/{model_id}")
def autotest_asr_model(model_id: str, db: Session = Depends(get_db)):
    """Test a single ASR model."""
result = AutotestResult()
_test_asr_model(db, model_id, result)
return result.to_dict()
def _test_llm_model(db: Session, model_id: str, result: AutotestResult):
    """Internal helper: run checks against an LLM model."""
    start_time = time.time()
    # 1. Check that the model exists
model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
duration_ms = int((time.time() - start_time) * 1000)
if not model:
result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
return
result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)
    # 2. Test the API connection
test_start = time.time()
try:
test_messages = [{"role": "user", "content": "Reply with 'OK'."}]
payload = {
"model": model.model_name or "gpt-3.5-turbo",
"messages": test_messages,
"max_tokens": 10,
"temperature": 0.1,
}
headers = {"Authorization": f"Bearer {model.api_key}"}
with httpx.Client(timeout=30.0) as client:
response = client.post(
f"{model.base_url}/chat/completions",
json=payload,
headers=headers
)
response.raise_for_status()
result_text = response.json()
latency_ms = int((time.time() - test_start) * 1000)
if result_text.get("choices"):
result.add_test("API Connection", True, f"Latency: {latency_ms}ms", latency_ms)
else:
result.add_test("API Connection", False, "Empty response", latency_ms)
except Exception as e:
latency_ms = int((time.time() - test_start) * 1000)
result.add_test("API Connection", False, str(e)[:200], latency_ms)
    # 3. Check the model configuration
if model.temperature is not None:
result.add_test("Temperature Setting", True, f"temperature={model.temperature}")
else:
result.add_test("Temperature Setting", True, "Using default")
    # 4. Test streaming support (optional)
if model.type == "text":
test_start = time.time()
try:
with httpx.Client(timeout=30.0) as client:
with client.stream(
"POST",
f"{model.base_url}/chat/completions",
json={
"model": model.model_name or "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Count from 1 to 3."}],
"stream": True,
},
headers=headers
) as response:
response.raise_for_status()
chunk_count = 0
for _ in response.iter_bytes():
chunk_count += 1
latency_ms = int((time.time() - test_start) * 1000)
result.add_test("Streaming Support", True, f"Received {chunk_count} chunks", latency_ms)
except Exception as e:
latency_ms = int((time.time() - test_start) * 1000)
result.add_test("Streaming Support", False, str(e)[:200], latency_ms)
def _test_asr_model(db: Session, model_id: str, result: AutotestResult):
    """Internal helper: run checks against an ASR model."""
    start_time = time.time()
    # 1. Check that the model exists
model = db.query(ASRModel).filter(ASRModel.id == model_id).first()
duration_ms = int((time.time() - start_time) * 1000)
if not model:
result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
return
result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)
    # 2. Check the configuration
if model.hotwords:
result.add_test("Hotwords Config", True, f"Hotwords: {len(model.hotwords)} words")
else:
result.add_test("Hotwords Config", True, "No hotwords configured")
    # 3. Check API availability
test_start = time.time()
try:
headers = {"Authorization": f"Bearer {model.api_key}"}
with httpx.Client(timeout=30.0) as client:
if model.vendor.lower() in ["siliconflow", "paraformer"]:
response = client.get(
f"{model.base_url}/asr",
headers=headers
)
elif model.vendor.lower() == "openai":
response = client.get(
f"{model.base_url}/audio/models",
headers=headers
)
else:
                # Generic health check
response = client.get(
f"{model.base_url}/health",
headers=headers
)
latency_ms = int((time.time() - test_start) * 1000)
if response.status_code in [200, 405]: # 405 = method not allowed but endpoint exists
result.add_test("API Availability", True, f"Status: {response.status_code}", latency_ms)
else:
result.add_test("API Availability", False, f"Status: {response.status_code}", latency_ms)
except httpx.TimeoutException:
latency_ms = int((time.time() - test_start) * 1000)
result.add_test("API Availability", False, "Connection timeout", latency_ms)
except Exception as e:
latency_ms = int((time.time() - test_start) * 1000)
result.add_test("API Availability", False, str(e)[:200], latency_ms)
    # 4. Check the language configuration
if model.language in ["zh", "en", "Multi-lingual"]:
result.add_test("Language Config", True, f"Language: {model.language}")
else:
result.add_test("Language Config", False, f"Unknown language: {model.language}")
# ============ Quick Health Check ============
@router.get("/health")
def health_check():
    """Quick health check."""
return {
"status": "healthy",
"timestamp": time.time(),
"tools": list(TOOL_REGISTRY.keys())
}
@router.post("/test-message")
def send_test_message(
llm_model_id: str,
message: str = "Hello, this is a test message.",
db: Session = Depends(get_db)
):
    """Send a test message."""
model = db.query(LLMModel).filter(LLMModel.id == llm_model_id).first()
if not model:
raise HTTPException(status_code=404, detail="LLM Model not found")
try:
payload = {
"model": model.model_name or "gpt-3.5-turbo",
"messages": [{"role": "user", "content": message}],
"max_tokens": 500,
"temperature": 0.7,
}
headers = {"Authorization": f"Bearer {model.api_key}"}
with httpx.Client(timeout=60.0) as client:
response = client.post(
f"{model.base_url}/chat/completions",
json=payload,
headers=headers
)
response.raise_for_status()
result = response.json()
reply = result.get("choices", [{}])[0].get("message", {}).get("content", "")
return {
"success": True,
"reply": reply,
"usage": result.get("usage", {})
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@@ -1,24 +1,203 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
# ============ Enums ============
class AssistantConfigMode(str, Enum):
PLATFORM = "platform"
DIFY = "dify"
FASTGPT = "fastgpt"
NONE = "none"
class LLMModelType(str, Enum):
TEXT = "text"
EMBEDDING = "embedding"
RERANK = "rerank"
class ASRLanguage(str, Enum):
ZH = "zh"
EN = "en"
MULTILINGUAL = "Multi-lingual"
class VoiceGender(str, Enum):
MALE = "Male"
FEMALE = "Female"
class CallRecordSource(str, Enum):
DEBUG = "debug"
EXTERNAL = "external"
class CallRecordStatus(str, Enum):
CONNECTED = "connected"
MISSED = "missed"
FAILED = "failed"
# ============ Voice ============
class VoiceBase(BaseModel):
    name: str
    vendor: str
    gender: str  # "Male" | "Female"
    language: str  # "zh" | "en"
    description: str = ""
class VoiceCreate(VoiceBase):
    model: str  # vendor voice-model identifier
    voice_key: str  # vendor voice key
speed: float = 1.0
gain: int = 0
pitch: int = 0
enabled: bool = True
class VoiceUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
model: Optional[str] = None
voice_key: Optional[str] = None
speed: Optional[float] = None
gain: Optional[int] = None
pitch: Optional[int] = None
enabled: Optional[bool] = None
class VoiceOut(VoiceBase):
    id: str
user_id: Optional[int] = None
model: Optional[str] = None
voice_key: Optional[str] = None
speed: float = 1.0
gain: int = 0
pitch: int = 0
enabled: bool = True
is_system: bool = False
created_at: Optional[datetime] = None
    class Config:
        from_attributes = True
class VoicePreviewRequest(BaseModel):
text: str
speed: Optional[float] = None
gain: Optional[int] = None
pitch: Optional[int] = None
class VoicePreviewResponse(BaseModel):
success: bool
audio_url: Optional[str] = None
duration_ms: Optional[int] = None
error: Optional[str] = None
# ============ LLM Model ============
class LLMModelBase(BaseModel):
name: str
vendor: str
type: LLMModelType
base_url: str
api_key: str
model_name: Optional[str] = None
temperature: Optional[float] = None
context_length: Optional[int] = None
enabled: bool = True
class LLMModelCreate(LLMModelBase):
    id: Optional[str] = None  # optional client-supplied ID; auto-generated when omitted (read as data.id in the router)
class LLMModelUpdate(BaseModel):
name: Optional[str] = None
base_url: Optional[str] = None
api_key: Optional[str] = None
model_name: Optional[str] = None
temperature: Optional[float] = None
context_length: Optional[int] = None
enabled: Optional[bool] = None
class LLMModelOut(LLMModelBase):
id: str
user_id: int
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
class LLMModelTestResponse(BaseModel):
success: bool
latency_ms: Optional[int] = None
message: Optional[str] = None
# ============ ASR Model ============
class ASRModelBase(BaseModel):
name: str
vendor: str
language: str # "zh" | "en" | "Multi-lingual"
base_url: str
api_key: str
model_name: Optional[str] = None
enabled: bool = True
class ASRModelCreate(ASRModelBase):
    id: Optional[str] = None  # optional client-supplied ID; auto-generated when omitted
    hotwords: List[str] = []
    enable_punctuation: bool = True
    enable_normalization: bool = True
class ASRModelUpdate(BaseModel):
name: Optional[str] = None
language: Optional[str] = None
base_url: Optional[str] = None
api_key: Optional[str] = None
model_name: Optional[str] = None
hotwords: Optional[List[str]] = None
enable_punctuation: Optional[bool] = None
enable_normalization: Optional[bool] = None
enabled: Optional[bool] = None
class ASRModelOut(ASRModelBase):
id: str
user_id: int
hotwords: List[str] = []
enable_punctuation: bool = True
enable_normalization: bool = True
created_at: Optional[datetime] = None
class Config:
from_attributes = True
class ASRTestRequest(BaseModel):
audio_url: Optional[str] = None
audio_data: Optional[str] = None # base64 encoded
class ASRTestResponse(BaseModel):
success: bool
transcript: Optional[str] = None
language: Optional[str] = None
confidence: Optional[float] = None
duration_ms: Optional[int] = None
latency_ms: Optional[int] = None
error: Optional[str] = None
# ============ Assistant ============
class AssistantBase(BaseModel):
    name: str
@@ -34,25 +213,56 @@ class AssistantBase(BaseModel):
    configMode: str = "platform"
    apiUrl: Optional[str] = None
    apiKey: Optional[str] = None
    # Model associations
llmModelId: Optional[str] = None
asrModelId: Optional[str] = None
embeddingModelId: Optional[str] = None
rerankModelId: Optional[str] = None
class AssistantCreate(AssistantBase):
    pass

class AssistantUpdate(BaseModel):
    name: Optional[str] = None
opener: Optional[str] = None
prompt: Optional[str] = None
knowledgeBaseId: Optional[str] = None
language: Optional[str] = None
voice: Optional[str] = None
speed: Optional[float] = None
hotwords: Optional[List[str]] = None
tools: Optional[List[str]] = None
interruptionSensitivity: Optional[int] = None
configMode: Optional[str] = None
apiUrl: Optional[str] = None
apiKey: Optional[str] = None
llmModelId: Optional[str] = None
asrModelId: Optional[str] = None
embeddingModelId: Optional[str] = None
rerankModelId: Optional[str] = None
class AssistantOut(AssistantBase):
    id: str
    callCount: int = 0
    created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
    class Config:
        from_attributes = True
class AssistantStats(BaseModel):
assistant_id: str
total_calls: int = 0
connected_calls: int = 0
missed_calls: int = 0
avg_duration_seconds: float = 0.0
today_calls: int = 0
# ============ Knowledge Base ============
class KnowledgeDocument(BaseModel):
    id: str
@@ -196,6 +406,7 @@ class TranscriptSegment(BaseModel):
    endMs: int
    durationMs: Optional[int] = None
    audioUrl: Optional[str] = None
emotion: Optional[str] = None
class CallRecordCreate(BaseModel):
@@ -208,6 +419,9 @@ class CallRecordUpdate(BaseModel):
    status: Optional[str] = None
    summary: Optional[str] = None
    duration_seconds: Optional[int] = None
ended_at: Optional[str] = None
cost: Optional[float] = None
metadata: Optional[dict] = None
class CallRecordOut(BaseModel):
@@ -220,6 +434,9 @@ class CallRecordOut(BaseModel):
    ended_at: Optional[str] = None
    duration_seconds: Optional[int] = None
    summary: Optional[str] = None
cost: float = 0.0
metadata: dict = {}
created_at: Optional[datetime] = None
    transcripts: List[TranscriptSegment] = []

    class Config:
@@ -246,6 +463,19 @@ class TranscriptOut(TranscriptCreate):
        from_attributes = True
# ============ History Stats ============
class HistoryStats(BaseModel):
total_calls: int = 0
connected_calls: int = 0
missed_calls: int = 0
failed_calls: int = 0
avg_duration_seconds: float = 0.0
total_cost: float = 0.0
by_status: dict = {}
by_source: dict = {}
daily_trend: List[dict] = []
# ============ Dashboard ============
class DashboardStats(BaseModel):
    totalCalls: int
@@ -269,3 +499,9 @@ class ListResponse(BaseModel):
    page: int
    limit: int
    list: List
class SearchResult(BaseModel):
id: str
started_at: str
matched_content: Optional[str] = None

api/docs/asr.md Normal file (+409 lines)

@@ -0,0 +1,409 @@
# Speech Recognition (ASR Model) API

The ASR API manages the configuration and invocation of speech-recognition models.

## Basics

| Item | Value |
|------|-----|
| Base URL | `/api/v1/asr` |
| Auth | Bearer Token (reserved) |
---
## Data Model

### ASRModel
```typescript
interface ASRModel {
  id: string;                    // unique model ID (8-char UUID)
  user_id: number;               // owning user ID
  name: string;                  // display name
  vendor: string;                // vendor: "OpenAI" | "SiliconFlow" | "Paraformer" | ...
  language: string;              // recognition language: "zh" | "en" | "Multi-lingual"
  base_url: string;              // API base URL
  api_key: string;               // API key
  model_name?: string;           // model name, e.g. "whisper-1" | "paraformer-v2"
  hotwords?: string[];           // hotword list
  enable_punctuation: boolean;   // emit punctuation
  enable_normalization: boolean; // apply text normalization
  enabled: boolean;              // enabled flag
  created_at: string;
}
```
---
## API Endpoints

### 1. List ASR models
```http
GET /api/v1/asr
```
**Query Parameters:**
| Param | Type | Required | Default | Description |
|------|------|------|--------|------|
| language | string | No | - | Filter by language: "zh" \| "en" \| "Multi-lingual" |
| enabled | boolean | No | - | Filter by enabled state |
| page | int | No | 1 | Page number |
| limit | int | No | 50 | Page size |
**Response:**
```json
{
"total": 3,
"page": 1,
"limit": 50,
"list": [
{
"id": "abc12345",
"user_id": 1,
"name": "Whisper 多语种识别",
"vendor": "OpenAI",
"language": "Multi-lingual",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-***",
"model_name": "whisper-1",
"enable_punctuation": true,
"enable_normalization": true,
"enabled": true,
"created_at": "2024-01-15T10:30:00Z"
},
{
"id": "def67890",
"user_id": 1,
"name": "SenseVoice 中文识别",
"vendor": "SiliconFlow",
"language": "zh",
"base_url": "https://api.siliconflow.cn/v1",
"api_key": "sf-***",
"model_name": "paraformer-v2",
"hotwords": ["小助手", "帮我"],
"enable_punctuation": true,
"enable_normalization": true,
"enabled": true,
"created_at": "2024-01-15T10:30:00Z"
}
]
}
```
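Listing is paginated, so a client that wants every model has to walk the pages. A minimal sketch, assuming a local deployment on port 8100 (the base URL and helper names below are illustrative, not part of the API):

```python
from typing import Iterator

API_BASE = "http://localhost:8100/api/v1"  # assumed local deployment


def iter_pages(total: int, limit: int) -> Iterator[int]:
    """Yield the 1-based page numbers needed to cover `total` items."""
    pages = (total + limit - 1) // limit  # ceiling division
    for page in range(1, max(pages, 1) + 1):
        yield page


def list_all_asr_models(limit: int = 50) -> list[dict]:
    import httpx  # imported lazily so the pure helper above stays dependency-free
    models: list[dict] = []
    with httpx.Client(timeout=10.0) as client:
        # The first response carries `total`, which drives the remaining requests.
        first = client.get(f"{API_BASE}/asr", params={"page": 1, "limit": limit}).json()
        models.extend(first["list"])
        for page in iter_pages(first["total"], limit):
            if page == 1:
                continue  # already fetched
            resp = client.get(f"{API_BASE}/asr", params={"page": page, "limit": limit})
            models.extend(resp.json()["list"])
    return models
```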
---
### 2. Get a single ASR model
```http
GET /api/v1/asr/{id}
```
**Path Parameters:**
| Param | Type | Description |
|------|------|------|
| id | string | Model ID |
**Response:**
```json
{
"id": "abc12345",
"user_id": 1,
"name": "Whisper 多语种识别",
"vendor": "OpenAI",
"language": "Multi-lingual",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-***",
"model_name": "whisper-1",
"hotwords": [],
"enable_punctuation": true,
"enable_normalization": true,
"enabled": true,
"created_at": "2024-01-15T10:30:00Z"
}
```
---
### 3. Create an ASR model
```http
POST /api/v1/asr
```
**Request Body:**
```json
{
"name": "SenseVoice 中文识别",
"vendor": "SiliconFlow",
"language": "zh",
"base_url": "https://api.siliconflow.cn/v1",
"api_key": "sk-your-api-key",
"model_name": "paraformer-v2",
"hotwords": ["小助手", "帮我"],
"enable_punctuation": true,
"enable_normalization": true,
"enabled": true
}
```
**Fields:**

| Field | Type | Required | Description |
|------|------|------|------|
| name | string | Yes | Display name |
| vendor | string | Yes | Vendor: "OpenAI" / "SiliconFlow" / "Paraformer" |
| language | string | Yes | Language: "zh" / "en" / "Multi-lingual" |
| base_url | string | Yes | API base URL |
| api_key | string | Yes | API key |
| model_name | string | No | Model name |
| hotwords | string[] | No | Hotword list; improves recognition accuracy |
| enable_punctuation | boolean | No | Emit punctuation, default true |
| enable_normalization | boolean | No | Apply text normalization, default true |
| enabled | boolean | No | Enabled flag, default true |
| id | string | No | Explicit model ID; auto-generated by default |
---
### 4. Update an ASR model
```http
PUT /api/v1/asr/{id}
```
**Request Body:** (partial update)
```json
{
"name": "Whisper-1 优化版",
"language": "zh",
"enable_punctuation": true,
"hotwords": ["新词1", "新词2"]
}
```
---
### 5. Delete an ASR model
```http
DELETE /api/v1/asr/{id}
```
**Response:**
```json
{
"message": "Deleted successfully"
}
```
---
### 6. Test an ASR model
```http
POST /api/v1/asr/{id}/test
```
**Request Body:**
```json
{
"audio_url": "https://example.com/test-audio.wav"
}
```
Or with Base64-encoded audio data:
```json
{
"audio_data": "UklGRi..."
}
```
**Response (success):**
```json
{
"success": true,
"transcript": "您好,请问有什么可以帮助您?",
"language": "zh",
"confidence": 0.95,
"latency_ms": 500
}
```
**Response (failure):**
```json
{
"success": false,
"error": "HTTP Error: 401 - Unauthorized"
}
```
---
### 7. Transcribe audio
```http
POST /api/v1/asr/{id}/transcribe
```
**Query Parameters:**
| Param | Type | Required | Description |
|------|------|------|------|
| audio_url | string | No* | Audio file URL |
| audio_data | string | No* | Base64-encoded audio data |
| hotwords | string[] | No | Hotword list |

*Provide at least one of `audio_url` and `audio_data`.
**Response:**
```json
{
"success": true,
"transcript": "您好,请问有什么可以帮助您?",
"language": "zh",
"confidence": 0.95
}
```
---
## Schema Definitions
```python
from enum import Enum
from pydantic import BaseModel
from typing import Optional, List
from datetime import datetime
class ASRLanguage(str, Enum):
ZH = "zh"
EN = "en"
MULTILINGUAL = "Multi-lingual"
class ASRModelBase(BaseModel):
name: str
vendor: str
language: str # "zh" | "en" | "Multi-lingual"
base_url: str
api_key: str
model_name: Optional[str] = None
hotwords: List[str] = []
enable_punctuation: bool = True
enable_normalization: bool = True
enabled: bool = True
class ASRModelCreate(ASRModelBase):
id: Optional[str] = None
class ASRModelUpdate(BaseModel):
name: Optional[str] = None
language: Optional[str] = None
base_url: Optional[str] = None
api_key: Optional[str] = None
model_name: Optional[str] = None
hotwords: Optional[List[str]] = None
enable_punctuation: Optional[bool] = None
enable_normalization: Optional[bool] = None
enabled: Optional[bool] = None
class ASRModelOut(ASRModelBase):
id: str
user_id: int
created_at: datetime
class Config:
from_attributes = True
class ASRTestRequest(BaseModel):
audio_url: Optional[str] = None
audio_data: Optional[str] = None # base64 encoded
class ASRTestResponse(BaseModel):
success: bool
transcript: Optional[str] = None
language: Optional[str] = None
confidence: Optional[float] = None
latency_ms: Optional[int] = None
error: Optional[str] = None
```
---
## Vendor Configuration Examples
### OpenAI Whisper
```json
{
"vendor": "OpenAI",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-xxx",
"model_name": "whisper-1",
"language": "Multi-lingual",
"enable_punctuation": true,
"enable_normalization": true
}
```
### SiliconFlow Paraformer
```json
{
"vendor": "SiliconFlow",
"base_url": "https://api.siliconflow.cn/v1",
"api_key": "sf-xxx",
"model_name": "paraformer-v2",
"language": "zh",
"hotwords": ["产品名称", "公司名"],
"enable_punctuation": true,
"enable_normalization": true
}
```
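Either preset can be registered through `POST /api/v1/asr` once the deployment-specific fields are filled in. A hedged sketch (the local base URL and helper names are assumptions):

```python
API_BASE = "http://localhost:8100/api/v1"  # assumed local deployment

# Vendor preset from the example above; name and api_key are supplied per deployment.
SILICONFLOW_PRESET = {
    "vendor": "SiliconFlow",
    "base_url": "https://api.siliconflow.cn/v1",
    "model_name": "paraformer-v2",
    "language": "zh",
    "hotwords": ["产品名称", "公司名"],
    "enable_punctuation": True,
    "enable_normalization": True,
}


def to_create_payload(preset: dict, name: str, api_key: str) -> dict:
    """Merge a vendor preset with the per-deployment fields required by ASRModelCreate."""
    return {**preset, "name": name, "api_key": api_key}


def register_asr_model(name: str, api_key: str, preset: dict = SILICONFLOW_PRESET) -> dict:
    import httpx  # imported lazily so the payload helper stays dependency-free
    payload = to_create_payload(preset, name, api_key)
    resp = httpx.post(f"{API_BASE}/asr", json=payload, timeout=10.0)
    resp.raise_for_status()
    return resp.json()
```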
---
## Unit Tests

The project includes unit tests for this module in `api/tests/test_asr.py`.

### Test Case Overview
| Test | Description |
|----------|------|
| test_get_asr_models_empty | Listing from an empty database |
| test_create_asr_model | Creating a model |
| test_create_asr_model_minimal | Creating with minimal data |
| test_get_asr_model_by_id | Fetching a single model |
| test_get_asr_model_not_found | Fetching a missing model |
| test_update_asr_model | Updating a model |
| test_delete_asr_model | Deleting a model |
| test_list_asr_models_with_pagination | Pagination |
| test_filter_asr_models_by_language | Filtering by language |
| test_filter_asr_models_by_enabled | Filtering by enabled state |
| test_create_asr_model_with_hotwords | Hotword configuration |
| test_test_asr_model_siliconflow | SiliconFlow vendor test |
| test_test_asr_model_openai | OpenAI vendor test |
| test_different_asr_languages | Multiple languages |
| test_different_asr_vendors | Multiple vendors |
### Running the Tests

```bash
# Run the ASR tests
pytest api/tests/test_asr.py -v

# Run the whole suite
pytest api/tests/ -v
```


@@ -7,9 +7,9 @@
| Module | File | Description |
|------|------|------|
| Assistant | [assistant.md](./assistant.md) | AI assistant management |
| LLM models | [llm.md](./llm.md) | LLM model configuration and management |
| ASR models | [asr.md](./asr.md) | Speech-recognition model configuration |
| Tools & autotest | [tools.md](./tools.md) | Tool registry and automated tests |
| History | [history-records.md](./history-records.md) | Call records and transcripts |

---

api/docs/llm.md Normal file (+401 lines)

@@ -0,0 +1,401 @@
# LLM Model API

The LLM API manages the configuration and invocation of large language models.

## Basics

| Item | Value |
|------|-----|
| Base URL | `/api/v1/llm` |
| Auth | Bearer Token (reserved) |
---
## Data Model

### LLMModel
```typescript
interface LLMModel {
  id: string;              // unique model ID (8-char UUID)
  user_id: number;         // owning user ID
  name: string;            // display name
  vendor: string;          // vendor: "OpenAI" | "SiliconFlow" | "Dify" | "FastGPT" | ...
  type: string;            // type: "text" | "embedding" | "rerank"
  base_url: string;        // API base URL
  api_key: string;         // API key
  model_name?: string;     // underlying model name, e.g. "gpt-4o"
  temperature?: number;    // temperature (0-2)
  context_length?: number; // context window length
  enabled: boolean;        // enabled flag
  created_at: string;
  updated_at: string;
}
```
---
## API Endpoints

### 1. List LLM models
```http
GET /api/v1/llm
```
**Query Parameters:**
| Param | Type | Required | Default | Description |
|------|------|------|--------|------|
| model_type | string | No | - | Filter by type: "text" \| "embedding" \| "rerank" |
| enabled | boolean | No | - | Filter by enabled state |
| page | int | No | 1 | Page number |
| limit | int | No | 50 | Page size |
**Response:**
```json
{
"total": 5,
"page": 1,
"limit": 50,
"list": [
{
"id": "abc12345",
"user_id": 1,
"name": "GPT-4o",
"vendor": "OpenAI",
"type": "text",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-***",
"model_name": "gpt-4o",
"temperature": 0.7,
"context_length": 128000,
"enabled": true,
"created_at": "2024-01-15T10:30:00Z",
"updated_at": "2024-01-15T10:30:00Z"
},
{
"id": "def67890",
"user_id": 1,
"name": "Embedding-3-Small",
"vendor": "OpenAI",
"type": "embedding",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-***",
"model_name": "text-embedding-3-small",
"enabled": true
}
]
}
```
---
### 2. Get a single LLM model
```http
GET /api/v1/llm/{id}
```
**Path Parameters:**
| Param | Type | Description |
|------|------|------|
| id | string | Model ID |
**Response:**
```json
{
"id": "abc12345",
"user_id": 1,
"name": "GPT-4o",
"vendor": "OpenAI",
"type": "text",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-***",
"model_name": "gpt-4o",
"temperature": 0.7,
"context_length": 128000,
"enabled": true,
"created_at": "2024-01-15T10:30:00Z",
"updated_at": "2024-01-15T10:30:00Z"
}
```
---
### 3. Create an LLM model
```http
POST /api/v1/llm
```
**Request Body:**
```json
{
"name": "GPT-4o",
"vendor": "OpenAI",
"type": "text",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-your-api-key",
"model_name": "gpt-4o",
"temperature": 0.7,
"context_length": 128000,
"enabled": true
}
```
**Field descriptions:**
| Field | Type | Required | Description |
|------|------|------|------|
| name | string | Yes | Display name |
| vendor | string | Yes | Vendor name |
| type | string | Yes | Model type: "text" / "embedding" / "rerank" |
| base_url | string | Yes | API base URL |
| api_key | string | Yes | API key |
| model_name | string | No | Actual model name |
| temperature | number | No | Temperature, default 0.7 |
| context_length | int | No | Context window length |
| enabled | boolean | No | Whether enabled, default true |
| id | string | No | Explicit model ID; auto-generated by default |
---
### 4. Update an LLM Model
```http
PUT /api/v1/llm/{id}
```
**Request Body:** (partial update; only the fields provided are changed)
```json
{
"name": "GPT-4o-Updated",
"temperature": 0.8,
"enabled": false
}
```
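Partial update means fields absent from the request body keep their stored values; a minimal sketch of that merge (illustrative only, not the actual handler):

```python
def apply_partial_update(stored: dict, patch: dict) -> dict:
    """Return a copy of `stored` with only the patched fields replaced."""
    updated = dict(stored)
    for key, value in patch.items():
        if key in updated:  # unknown fields are ignored
            updated[key] = value
    return updated

model = {"name": "GPT-4o", "temperature": 0.7, "enabled": True}
patched = apply_partial_update(model, {"temperature": 0.8, "enabled": False})
print(patched)  # {'name': 'GPT-4o', 'temperature': 0.8, 'enabled': False}
```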
---
### 5. Delete an LLM Model
```http
DELETE /api/v1/llm/{id}
```
**Response:**
```json
{
"message": "Deleted successfully"
}
```
---
### 6. Test an LLM Model Connection
```http
POST /api/v1/llm/{id}/test
```
**Response:**
```json
{
"success": true,
"latency_ms": 150,
"message": "Connection successful"
}
```
**Error Response:**
```json
{
"success": false,
"latency_ms": 200,
"message": "HTTP Error: 401 - Unauthorized"
}
```
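A connection test of this shape can be sketched with `time.perf_counter`; the `probe` callable below is a hypothetical stand-in for the real HTTP request:

```python
import time

def run_connection_test(probe) -> dict:
    """Time a probe callable and report the result in the endpoint's shape."""
    start = time.perf_counter()
    try:
        probe()
        success, message = True, "Connection successful"
    except Exception as exc:
        success, message = False, str(exc)
    latency_ms = int((time.perf_counter() - start) * 1000)
    return {"success": success, "latency_ms": latency_ms, "message": message}

def unauthorized():
    # Simulates a failed upstream call
    raise RuntimeError("HTTP Error: 401 - Unauthorized")

print(run_connection_test(lambda: None))
print(run_connection_test(unauthorized))
```

Note that `latency_ms` is reported even on failure, matching the error example above.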
---
### 7. Chat with an LLM Model
```http
POST /api/v1/llm/{id}/chat
```
**Query Parameters:**
| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| message | string | Yes | - | User message |
| system_prompt | string | No | - | System prompt |
| max_tokens | int | No | 1000 | Maximum number of generated tokens |
| temperature | number | No | model config | Temperature |
**Response:**
```json
{
"success": true,
"reply": "您好!有什么可以帮助您的?",
"usage": {
"prompt_tokens": 20,
"completion_tokens": 15,
"total_tokens": 35
}
}
```
---
## Schema Definitions
```python
from enum import Enum
from pydantic import BaseModel
from typing import Optional
from datetime import datetime
class LLMModelType(str, Enum):
TEXT = "text"
EMBEDDING = "embedding"
RERANK = "rerank"
class LLMModelBase(BaseModel):
name: str
vendor: str
type: LLMModelType
base_url: str
api_key: str
model_name: Optional[str] = None
temperature: Optional[float] = None
context_length: Optional[int] = None
enabled: bool = True
class LLMModelCreate(LLMModelBase):
id: Optional[str] = None
class LLMModelUpdate(BaseModel):
name: Optional[str] = None
vendor: Optional[str] = None
base_url: Optional[str] = None
api_key: Optional[str] = None
model_name: Optional[str] = None
temperature: Optional[float] = None
context_length: Optional[int] = None
enabled: Optional[bool] = None
class LLMModelOut(LLMModelBase):
id: str
user_id: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class LLMModelTestResponse(BaseModel):
success: bool
latency_ms: int
message: str
```
---
## Vendor Configuration Examples
### OpenAI
```json
{
"vendor": "OpenAI",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-xxx",
"model_name": "gpt-4o",
"type": "text",
"temperature": 0.7
}
```
### SiliconFlow
```json
{
"vendor": "SiliconFlow",
"base_url": "https://api.siliconflow.com/v1",
"api_key": "sf-xxx",
"model_name": "deepseek-v3",
"type": "text",
"temperature": 0.7
}
```
### Dify
```json
{
"vendor": "Dify",
"base_url": "https://your-dify.domain.com/v1",
"api_key": "app-xxx",
"model_name": "gpt-4",
"type": "text"
}
```
### Embedding Model
```json
{
"vendor": "OpenAI",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-xxx",
"model_name": "text-embedding-3-small",
"type": "embedding"
}
```
---
## Unit Tests
A full unit-test suite lives in `api/tests/test_llm.py`.
### Test Case Overview
| Test Method | Description |
|----------|------|
| test_get_llm_models_empty | Fetch from an empty database |
| test_create_llm_model | Create a model |
| test_create_llm_model_minimal | Create with minimal data |
| test_get_llm_model_by_id | Fetch a single model |
| test_get_llm_model_not_found | Fetch a non-existent model |
| test_update_llm_model | Update a model |
| test_delete_llm_model | Delete a model |
| test_list_llm_models_with_pagination | Pagination |
| test_filter_llm_models_by_type | Filter by type |
| test_filter_llm_models_by_enabled | Filter by enabled status |
| test_create_llm_model_with_all_fields | Create with all fields set |
| test_test_llm_model_success | Connection test, success path |
| test_test_llm_model_failure | Connection test, failure path |
| test_different_llm_vendors | Multiple vendors |
| test_embedding_llm_model | Embedding model |
### Running the Tests
```bash
# Run the LLM tests
pytest api/tests/test_llm.py -v
# Run all tests
pytest api/tests/ -v
```

api/docs/tools.md Normal file

@@ -0,0 +1,445 @@
# Tools & Autotest API
The Tools & Autotest API manages the list of available tools and the automated self-test features.
## Basic Information
| Item | Value |
|------|-----|
| Base URL | `/api/v1/tools` |
| Auth | Bearer Token (reserved) |
---
## Available Tools (Tool Registry)
The system ships with the following built-in tools (the `name` values are the literal display names returned by the API):
| Tool ID | Name | Description |
|--------|------|------|
| search | 网络搜索 (Web Search) | Search the internet for up-to-date information |
| calculator | 计算器 (Calculator) | Evaluate mathematical expressions |
| weather | 天气查询 (Weather) | Look up the weather for a given city |
| translate | 翻译 (Translate) | Translate text into a target language |
| knowledge | 知识库查询 (Knowledge Base) | Retrieve relevant information from a knowledge base |
| code_interpreter | 代码执行 (Code Interpreter) | Safely execute Python code |
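Each registry entry already carries a JSON-Schema parameter spec, so it maps naturally onto an OpenAI-style function-calling `tools` array if the backend forwards tools to an OpenAI-compatible chat endpoint. A sketch of that (assumed) conversion:

```python
def to_function_tools(registry: dict) -> list:
    """Convert tool-registry entries into OpenAI-style `tools` definitions."""
    return [
        {
            "type": "function",
            "function": {
                "name": tool_id,               # registry key becomes the function name
                "description": spec["description"],
                "parameters": spec["parameters"],  # already JSON Schema
            },
        }
        for tool_id, spec in registry.items()
    ]

registry = {
    "calculator": {
        "name": "计算器",
        "description": "执行数学计算",
        "parameters": {
            "type": "object",
            "properties": {"expression": {"type": "string"}},
            "required": ["expression"],
        },
    }
}
print(to_function_tools(registry)[0]["function"]["name"])  # calculator
```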
---
## API Endpoints
### 1. List Available Tools
```http
GET /api/v1/tools/list
```
**Response:**
```json
{
"tools": {
"search": {
"name": "网络搜索",
"description": "搜索互联网获取最新信息",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "搜索关键词"}
},
"required": ["query"]
}
},
"calculator": {
"name": "计算器",
"description": "执行数学计算",
"parameters": {
"type": "object",
"properties": {
"expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}
},
"required": ["expression"]
}
},
"weather": {
"name": "天气查询",
"description": "查询指定城市的天气",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "城市名称"}
},
"required": ["city"]
}
},
"translate": {
"name": "翻译",
"description": "翻译文本到指定语言",
"parameters": {
"type": "object",
"properties": {
"text": {"type": "string", "description": "要翻译的文本"},
"target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"}
},
"required": ["text", "target_lang"]
}
},
"knowledge": {
"name": "知识库查询",
"description": "从知识库中检索相关信息",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "查询内容"},
"kb_id": {"type": "string", "description": "知识库ID"}
},
"required": ["query"]
}
},
"code_interpreter": {
"name": "代码执行",
"description": "安全地执行Python代码",
"parameters": {
"type": "object",
"properties": {
"code": {"type": "string", "description": "要执行的Python代码"}
},
"required": ["code"]
}
}
}
}
```
---
### 2. Get Tool Details
```http
GET /api/v1/tools/list/{tool_id}
```
**Path Parameters:**
| Parameter | Type | Description |
|------|------|------|
| tool_id | string | Tool ID |
**Response:**
```json
{
"name": "计算器",
"description": "执行数学计算",
"parameters": {
"type": "object",
"properties": {
"expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}
},
"required": ["expression"]
}
}
```
**Error Response (tool not found):**
```json
{
"detail": "Tool not found"
}
```
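The `calculator` tool must evaluate expressions like `2 + 3 * 4` without reaching for `eval`. One common safe approach restricts the parsed AST to arithmetic nodes; this is a sketch of the idea, not the project's actual implementation:

```python
import ast
import operator

# Whitelisted arithmetic operators
_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.Div: operator.truediv,
        ast.Pow: operator.pow, ast.USub: operator.neg}

def safe_calc(expression: str) -> float:
    """Evaluate an arithmetic expression, rejecting anything non-numeric."""
    def walk(node):
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.operand))
        raise ValueError("unsupported expression")
    return walk(ast.parse(expression, mode="eval").body)

print(safe_calc("2 + 3 * 4"))  # 14
```

Anything outside the whitelist (names, calls, attribute access) raises `ValueError` instead of executing.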
---
### 3. Health Check
```http
GET /api/v1/tools/health
```
**Response:**
```json
{
"status": "healthy",
"timestamp": 1705315200.123,
"tools": ["search", "calculator", "weather", "translate", "knowledge", "code_interpreter"]
}
```
---
## Autotest
### 4. Run the Full Autotest
```http
POST /api/v1/tools/autotest
```
**Query Parameters:**
| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| llm_model_id | string | No | - | LLM model ID |
| asr_model_id | string | No | - | ASR model ID |
| test_llm | boolean | No | true | Whether to test the LLM |
| test_asr | boolean | No | true | Whether to test the ASR |
**Response:**
```json
{
"id": "abc12345",
"started_at": 1705315200.0,
"duration_ms": 2500,
"tests": [
{
"name": "Model Existence",
"passed": true,
"message": "Found model: GPT-4o",
"duration_ms": 15
},
{
"name": "API Connection",
"passed": true,
"message": "Latency: 150ms",
"duration_ms": 150
},
{
"name": "Temperature Setting",
"passed": true,
"message": "temperature=0.7"
},
{
"name": "Streaming Support",
"passed": true,
"message": "Received 15 chunks",
"duration_ms": 800
}
],
"summary": {
"passed": 4,
"failed": 0,
"total": 4
}
}
```
---
### 5. Test a Single LLM Model
```http
POST /api/v1/tools/autotest/llm/{model_id}
```
**Path Parameters:**
| Parameter | Type | Description |
|------|------|------|
| model_id | string | LLM model ID |
**Response:**
```json
{
"id": "llm_test_001",
"started_at": 1705315200.0,
"duration_ms": 1200,
"tests": [
{
"name": "Model Existence",
"passed": true,
"message": "Found model: GPT-4o",
"duration_ms": 10
},
{
"name": "API Connection",
"passed": true,
"message": "Latency: 180ms",
"duration_ms": 180
},
{
"name": "Temperature Setting",
"passed": true,
"message": "temperature=0.7"
},
{
"name": "Streaming Support",
"passed": true,
"message": "Received 12 chunks",
"duration_ms": 650
}
],
"summary": {
"passed": 4,
"failed": 0,
"total": 4
}
}
```
---
### 6. Test a Single ASR Model
```http
POST /api/v1/tools/autotest/asr/{model_id}
```
**Path Parameters:**
| Parameter | Type | Description |
|------|------|------|
| model_id | string | ASR model ID |
**Response:**
```json
{
"id": "asr_test_001",
"started_at": 1705315200.0,
"duration_ms": 800,
"tests": [
{
"name": "Model Existence",
"passed": true,
"message": "Found model: Whisper-1",
"duration_ms": 8
},
{
"name": "Hotwords Config",
"passed": true,
"message": "Hotwords: 3 words"
},
{
"name": "API Availability",
"passed": true,
"message": "Status: 200",
"duration_ms": 250
},
{
"name": "Language Config",
"passed": true,
"message": "Language: zh"
}
],
"summary": {
"passed": 4,
"failed": 0,
"total": 4
}
}
```
---
### 7. Send a Test Message
```http
POST /api/v1/tools/test-message
```
**Query Parameters:**
| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| llm_model_id | string | Yes | - | LLM model ID |
| message | string | No | "Hello, this is a test message." | Test message |
**Response:**
```json
{
"success": true,
"reply": "Hello! This is a test reply from GPT-4o.",
"usage": {
"prompt_tokens": 15,
"completion_tokens": 12,
"total_tokens": 27
}
}
```
**Error Response (model not found):**
```json
{
"detail": "LLM Model not found"
}
```
---
## Test Result Structure
### AutotestResult
```typescript
interface AutotestResult {
  id: string;           // Test run ID
  started_at: number;   // Start timestamp
  duration_ms: number;  // Total duration in milliseconds
  tests: TestCase[];    // Individual test cases
  summary: TestSummary; // Aggregate summary
}
interface TestCase {
  name: string;         // Test name
  passed: boolean;      // Whether the test passed
  message: string;      // Result message
  duration_ms: number;  // Duration in milliseconds
}
interface TestSummary {
  passed: number;       // Number of passed tests
  failed: number;       // Number of failed tests
  total: number;        // Total number of tests
}
```
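The `summary` block is derived from the `tests` list; a minimal sketch:

```python
def summarize(tests: list) -> dict:
    """Compute the pass/fail summary for an autotest run."""
    passed = sum(1 for t in tests if t["passed"])
    return {"passed": passed, "failed": len(tests) - passed, "total": len(tests)}

tests = [
    {"name": "Model Existence", "passed": True},
    {"name": "API Connection", "passed": True},
    {"name": "Streaming Support", "passed": False},
]
print(summarize(tests))  # {'passed': 2, 'failed': 1, 'total': 3}
```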
---
## Test Items
### LLM Model Test Items
| Test Name | Description |
|----------|------|
| Model Existence | Check that the model exists in the database |
| API Connection | Test the API connection and measure latency |
| Temperature Setting | Check the temperature configuration |
| Streaming Support | Test streaming response support |
### ASR Model Test Items
| Test Name | Description |
|----------|------|
| Model Existence | Check that the model exists in the database |
| Hotwords Config | Check the hotword configuration |
| API Availability | Test API availability |
| Language Config | Check the language configuration |
---
## Unit Tests
A full unit-test suite lives in `api/tests/test_tools.py`.
### Test Case Overview
| Test Class | Description |
|--------|------|
| TestToolsAPI | Basic functionality: tool list, health check, etc. |
| TestAutotestAPI | Full coverage of the autotest features |
### Running the Tests
```bash
# Run the tools tests
pytest api/tests/test_tools.py -v
# Run all tests
pytest api/tests/ -v
```

api/init_db.py

@@ -6,38 +6,381 @@ import sys
# Add the project path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from app.db import Base, engine, DATABASE_URL
from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel
def init_db():
    """Create all tables"""
    # Ensure the data directory exists
    data_dir = os.path.dirname(DATABASE_URL.replace("sqlite:///", ""))
    os.makedirs(data_dir, exist_ok=True)
    print("📦 Creating database tables...")
    Base.metadata.drop_all(bind=engine)  # drop old tables
    Base.metadata.create_all(bind=engine)
    print("✅ Database tables created")
def init_default_data():
    from sqlalchemy.orm import Session
    from app.db import SessionLocal
    db = SessionLocal()
    try:
        # Skip if voice data already exists
        if db.query(Voice).count() == 0:
            # SiliconFlow CosyVoice 2.0 preset voices
            # Reference: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
            voices = [
                # Male voices
                Voice(id="alex", name="Alex", vendor="SiliconFlow", gender="Male", language="en",
                      description="Steady male voice.", is_system=True),
                Voice(id="benjamin", name="Benjamin", vendor="SiliconFlow", gender="Male", language="en",
                      description="Deep male voice.", is_system=True),
                Voice(id="charles", name="Charles", vendor="SiliconFlow", gender="Male", language="en",
                      description="Magnetic male voice.", is_system=True),
                Voice(id="david", name="David", vendor="SiliconFlow", gender="Male", language="en",
                      description="Cheerful male voice.", is_system=True),
                # Female voices
                Voice(id="anna", name="Anna", vendor="SiliconFlow", gender="Female", language="en",
                      description="Steady female voice.", is_system=True),
                Voice(id="bella", name="Bella", vendor="SiliconFlow", gender="Female", language="en",
                      description="Passionate female voice.", is_system=True),
                Voice(id="claire", name="Claire", vendor="SiliconFlow", gender="Female", language="en",
                      description="Gentle female voice.", is_system=True),
                Voice(id="diana", name="Diana", vendor="SiliconFlow", gender="Female", language="en",
                      description="Cheerful female voice.", is_system=True),
                # Additional voices (optional extensions, incl. Chinese dialects)
                Voice(id="amador", name="Amador", vendor="SiliconFlow", gender="Male", language="zh",
                      description="Male voice with Spanish accent."),
                Voice(id="aelora", name="Aelora", vendor="SiliconFlow", gender="Female", language="en",
                      description="Elegant female voice."),
                Voice(id="aelwin", name="Aelwin", vendor="SiliconFlow", gender="Male", language="en",
                      description="Deep male voice."),
                Voice(id="blooming", name="Blooming", vendor="SiliconFlow", gender="Female", language="en",
                      description="Fresh and clear female voice."),
                Voice(id="elysia", name="Elysia", vendor="SiliconFlow", gender="Female", language="en",
                      description="Smooth and silky female voice."),
                Voice(id="leo", name="Leo", vendor="SiliconFlow", gender="Male", language="en",
                      description="Young male voice."),
                Voice(id="lin", name="Lin", vendor="SiliconFlow", gender="Female", language="zh",
                      description="Standard Chinese female voice."),
                Voice(id="rose", name="Rose", vendor="SiliconFlow", gender="Female", language="en",
                      description="Soft and gentle female voice."),
                Voice(id="shao", name="Shao", vendor="SiliconFlow", gender="Male", language="zh",
                      description="Deep Chinese male voice."),
                Voice(id="sky", name="Sky", vendor="SiliconFlow", gender="Male", language="en",
                      description="Clear and bright male voice."),
                Voice(id="ael西山", name="Ael西山", vendor="SiliconFlow", gender="Female", language="zh",
                      description="Female voice with Chinese dialect."),
            ]
            for v in voices:
                db.add(v)
            db.commit()
            print("✅ Default voice data initialized (SiliconFlow CosyVoice 2.0)")
    finally:
        db.close()
def init_default_assistants():
"""初始化默认助手"""
from sqlalchemy.orm import Session
from app.db import SessionLocal
db = SessionLocal()
try:
if db.query(Assistant).count() == 0:
assistants = [
Assistant(
id="default",
user_id=1,
name="AI 助手",
call_count=0,
opener="你好我是AI助手有什么可以帮你的吗",
prompt="你是一个友好的AI助手请用简洁清晰的语言回答用户的问题。",
language="zh",
voice="anna",
speed=1.0,
hotwords=[],
tools=["search", "calculator"],
interruption_sensitivity=500,
config_mode="platform",
llm_model_id="deepseek-chat",
asr_model_id="paraformer-v2",
),
Assistant(
id="customer_service",
user_id=1,
name="客服助手",
call_count=0,
opener="您好,欢迎致电客服中心,请问有什么可以帮您?",
prompt="你是一个专业的客服人员,耐心解答客户问题,提供优质的服务体验。",
language="zh",
voice="bella",
speed=1.0,
hotwords=["客服", "投诉", "咨询"],
tools=["search"],
interruption_sensitivity=600,
config_mode="platform",
),
Assistant(
id="english_tutor",
user_id=1,
name="英语导师",
call_count=0,
opener="Hello! I'm your English learning companion. How can I help you today?",
prompt="You are a friendly English tutor. Help users practice English conversation and explain grammar points clearly.",
language="en",
voice="alex",
speed=1.0,
hotwords=["grammar", "vocabulary", "practice"],
tools=[],
interruption_sensitivity=400,
config_mode="platform",
),
]
for a in assistants:
db.add(a)
db.commit()
print("✅ 默认助手数据已初始化")
finally:
db.close()
def init_default_workflows():
"""初始化默认工作流"""
from sqlalchemy.orm import Session
from app.db import SessionLocal
from datetime import datetime
db = SessionLocal()
try:
if db.query(Workflow).count() == 0:
now = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
workflows = [
Workflow(
id="simple_conversation",
user_id=1,
name="简单对话",
node_count=2,
created_at=now,
updated_at=now,
global_prompt="处理简单的对话流程,用户问什么答什么。",
nodes=[
{"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
{"id": "2", "type": "ai_reply", "position": {"x": 300, "y": 100}, "data": {"label": "AI回复"}},
],
edges=[{"source": "1", "target": "2", "id": "e1-2"}],
),
Workflow(
id="voice_input_flow",
user_id=1,
name="语音输入流程",
node_count=4,
created_at=now,
updated_at=now,
global_prompt="处理语音输入的完整流程。",
nodes=[
{"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
{"id": "2", "type": "asr", "position": {"x": 250, "y": 100}, "data": {"label": "语音识别"}},
{"id": "3", "type": "llm", "position": {"x": 400, "y": 100}, "data": {"label": "LLM处理"}},
{"id": "4", "type": "tts", "position": {"x": 550, "y": 100}, "data": {"label": "语音合成"}},
],
edges=[
{"source": "1", "target": "2", "id": "e1-2"},
{"source": "2", "target": "3", "id": "e2-3"},
{"source": "3", "target": "4", "id": "e3-4"},
],
),
]
for w in workflows:
db.add(w)
db.commit()
print("✅ 默认工作流数据已初始化")
finally:
db.close()
def init_default_knowledge_bases():
"""初始化默认知识库"""
from sqlalchemy.orm import Session
from app.db import SessionLocal
db = SessionLocal()
try:
if db.query(KnowledgeBase).count() == 0:
kb = KnowledgeBase(
id="default_kb",
user_id=1,
name="默认知识库",
description="系统默认知识库,用于存储常见问题解答。",
embedding_model="text-embedding-3-small",
chunk_size=500,
chunk_overlap=50,
doc_count=0,
chunk_count=0,
status="active",
)
db.add(kb)
db.commit()
print("✅ 默认知识库已初始化")
finally:
db.close()
def init_default_llm_models():
"""初始化默认LLM模型"""
from sqlalchemy.orm import Session
from app.db import SessionLocal
db = SessionLocal()
try:
if db.query(LLMModel).count() == 0:
llm_models = [
LLMModel(
id="deepseek-chat",
user_id=1,
name="DeepSeek Chat",
vendor="SiliconFlow",
type="text",
base_url="https://api.deepseek.com",
api_key="YOUR_API_KEY", # 用户需替换
model_name="deepseek-chat",
temperature=0.7,
context_length=4096,
enabled=True,
),
LLMModel(
id="deepseek-reasoner",
user_id=1,
name="DeepSeek Reasoner",
vendor="SiliconFlow",
type="text",
base_url="https://api.deepseek.com",
api_key="YOUR_API_KEY",
model_name="deepseek-reasoner",
temperature=0.7,
context_length=4096,
enabled=True,
),
LLMModel(
id="gpt-4o",
user_id=1,
name="GPT-4o",
vendor="OpenAI",
type="text",
base_url="https://api.openai.com/v1",
api_key="YOUR_API_KEY",
model_name="gpt-4o",
temperature=0.7,
context_length=16384,
enabled=True,
),
LLMModel(
id="glm-4",
user_id=1,
name="GLM-4",
vendor="ZhipuAI",
type="text",
base_url="https://open.bigmodel.cn/api/paas/v4",
api_key="YOUR_API_KEY",
model_name="glm-4",
temperature=0.7,
context_length=8192,
enabled=True,
),
LLMModel(
id="text-embedding-3-small",
user_id=1,
name="Embedding 3 Small",
vendor="OpenAI",
type="embedding",
base_url="https://api.openai.com/v1",
api_key="YOUR_API_KEY",
model_name="text-embedding-3-small",
enabled=True,
),
]
for m in llm_models:
db.add(m)
db.commit()
print("✅ 默认LLM模型已初始化")
finally:
db.close()
def init_default_asr_models():
"""初始化默认ASR模型"""
from sqlalchemy.orm import Session
from app.db import SessionLocal
db = SessionLocal()
try:
if db.query(ASRModel).count() == 0:
asr_models = [
ASRModel(
id="paraformer-v2",
user_id=1,
name="Paraformer V2",
vendor="SiliconFlow",
language="zh",
base_url="https://api.siliconflow.cn/v1",
api_key="YOUR_API_KEY",
model_name="paraformer-v2",
hotwords=["人工智能", "机器学习"],
enable_punctuation=True,
enable_normalization=True,
enabled=True,
),
ASRModel(
id="paraformer-en",
user_id=1,
name="Paraformer English",
vendor="SiliconFlow",
language="en",
base_url="https://api.siliconflow.cn/v1",
api_key="YOUR_API_KEY",
model_name="paraformer-en",
hotwords=[],
enable_punctuation=True,
enable_normalization=True,
enabled=True,
),
ASRModel(
id="whisper-1",
user_id=1,
name="Whisper",
vendor="OpenAI",
language="Multi-lingual",
base_url="https://api.openai.com/v1",
api_key="YOUR_API_KEY",
model_name="whisper-1",
hotwords=[],
enable_punctuation=True,
enable_normalization=True,
enabled=True,
),
ASRModel(
id="sensevoice",
user_id=1,
name="SenseVoice",
vendor="SiliconFlow",
language="Multi-lingual",
base_url="https://api.siliconflow.cn/v1",
api_key="YOUR_API_KEY",
model_name="sensevoice",
hotwords=[],
enable_punctuation=True,
enable_normalization=True,
enabled=True,
),
]
for m in asr_models:
db.add(m)
db.commit()
print("✅ 默认ASR模型已初始化")
    finally:
        db.close()
@@ -46,7 +389,12 @@ if __name__ == "__main__":
    # Ensure the data directory exists
    data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
    os.makedirs(data_dir, exist_ok=True)
    init_db()
    init_default_data()
    init_default_assistants()
    init_default_workflows()
    init_default_knowledge_bases()
    init_default_llm_models()
    init_default_asr_models()
    print("🎉 Database initialization complete!")

View File

@@ -6,6 +6,9 @@ import os
from app.db import Base, engine
from app.routers import assistants, history, knowledge
# Configuration
PORT = int(os.getenv("PORT", 8100))
@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -52,22 +55,59 @@
    from sqlalchemy.orm import Session
    from app.db import SessionLocal
    from app.models import Voice
    db = SessionLocal()
    try:
        # Skip if voice data already exists
        if db.query(Voice).count() == 0:
            # SiliconFlow CosyVoice 2.0 preset voices
            # Reference: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
            voices = [
                # Male voices
                Voice(id="alex", name="Alex", vendor="SiliconFlow", gender="Male", language="en",
                      description="Steady male voice.", is_system=True),
                Voice(id="benjamin", name="Benjamin", vendor="SiliconFlow", gender="Male", language="en",
                      description="Deep male voice.", is_system=True),
                Voice(id="charles", name="Charles", vendor="SiliconFlow", gender="Male", language="en",
                      description="Magnetic male voice.", is_system=True),
                Voice(id="david", name="David", vendor="SiliconFlow", gender="Male", language="en",
                      description="Cheerful male voice.", is_system=True),
                # Female voices
                Voice(id="anna", name="Anna", vendor="SiliconFlow", gender="Female", language="en",
                      description="Steady female voice.", is_system=True),
                Voice(id="bella", name="Bella", vendor="SiliconFlow", gender="Female", language="en",
                      description="Passionate female voice.", is_system=True),
                Voice(id="claire", name="Claire", vendor="SiliconFlow", gender="Female", language="en",
                      description="Gentle female voice.", is_system=True),
                Voice(id="diana", name="Diana", vendor="SiliconFlow", gender="Female", language="en",
                      description="Cheerful female voice.", is_system=True),
                # Additional voices (optional extensions, incl. Chinese dialects)
                Voice(id="amador", name="Amador", vendor="SiliconFlow", gender="Male", language="zh",
                      description="Male voice with Spanish accent."),
                Voice(id="aelora", name="Aelora", vendor="SiliconFlow", gender="Female", language="en",
                      description="Elegant female voice."),
                Voice(id="aelwin", name="Aelwin", vendor="SiliconFlow", gender="Male", language="en",
                      description="Deep male voice."),
                Voice(id="blooming", name="Blooming", vendor="SiliconFlow", gender="Female", language="en",
                      description="Fresh and clear female voice."),
                Voice(id="elysia", name="Elysia", vendor="SiliconFlow", gender="Female", language="en",
                      description="Smooth and silky female voice."),
                Voice(id="leo", name="Leo", vendor="SiliconFlow", gender="Male", language="en",
                      description="Young male voice."),
                Voice(id="lin", name="Lin", vendor="SiliconFlow", gender="Female", language="zh",
                      description="Standard Chinese female voice."),
                Voice(id="rose", name="Rose", vendor="SiliconFlow", gender="Female", language="en",
                      description="Soft and gentle female voice."),
                Voice(id="shao", name="Shao", vendor="SiliconFlow", gender="Male", language="zh",
                      description="Deep Chinese male voice."),
                Voice(id="sky", name="Sky", vendor="SiliconFlow", gender="Male", language="en",
                      description="Clear and bright male voice."),
                Voice(id="ael西山", name="Ael西山", vendor="SiliconFlow", gender="Female", language="zh",
                      description="Female voice with Chinese dialect."),
            ]
            for v in voices:
                db.add(v)
            db.commit()
            print("✅ Default voice data initialized (SiliconFlow CosyVoice 2.0)")
    finally:
        db.close()

api/pytest.ini Normal file

@@ -0,0 +1,8 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short
filterwarnings =
ignore::DeprecationWarning

api/run_tests.bat Normal file

@@ -0,0 +1,14 @@
@echo off
REM Run API tests
cd /d "%~dp0"
REM Install test dependencies
echo Installing test dependencies...
pip install pytest pytest-cov -q
REM Run tests
echo Running tests...
pytest tests/ -v --tb=short
pause

api/tests/__init__.py Normal file

@@ -0,0 +1 @@
# Tests package

api/tests/conftest.py Normal file

@@ -0,0 +1,137 @@
"""Pytest fixtures for API tests"""
import os
import sys
import pytest
# Add api directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.db import Base, get_db
from app.main import app
# Use in-memory SQLite for testing
DATABASE_URL = "sqlite:///:memory:"
engine = create_engine(
DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture(scope="function")
def db_session():
"""Create a fresh database session for each test"""
# Create all tables
Base.metadata.create_all(bind=engine)
session = TestingSessionLocal()
try:
yield session
finally:
session.close()
# Drop all tables after test
Base.metadata.drop_all(bind=engine)
@pytest.fixture(scope="function")
def client(db_session):
"""Create a test client with database dependency override"""
def override_get_db():
try:
yield db_session
finally:
pass
app.dependency_overrides[get_db] = override_get_db
with TestClient(app) as test_client:
yield test_client
app.dependency_overrides.clear()
@pytest.fixture
def sample_voice_data():
"""Sample voice data for testing"""
return {
"name": "Test Voice",
"vendor": "TestVendor",
"gender": "Female",
"language": "zh",
"description": "A test voice for unit testing",
"model": "test-model",
"voice_key": "test-key",
"speed": 1.0,
"gain": 0,
"pitch": 0,
"enabled": True
}
@pytest.fixture
def sample_assistant_data():
"""Sample assistant data for testing"""
return {
"name": "Test Assistant",
"opener": "Hello, welcome!",
"prompt": "You are a helpful assistant.",
"language": "zh",
"speed": 1.0,
"hotwords": ["test", "hello"],
"tools": [],
"configMode": "platform"
}
@pytest.fixture
def sample_call_record_data():
"""Sample call record data for testing"""
return {
"user_id": 1,
"assistant_id": None,
"source": "debug"
}
@pytest.fixture
def sample_llm_model_data():
"""Sample LLM model data for testing"""
return {
"id": "test-llm-001",
"name": "Test LLM Model",
"vendor": "TestVendor",
"type": "text",
"base_url": "https://api.test.com/v1",
"api_key": "test-api-key",
"model_name": "test-model",
"temperature": 0.7,
"context_length": 4096,
"enabled": True
}
@pytest.fixture
def sample_asr_model_data():
"""Sample ASR model data for testing"""
return {
"id": "test-asr-001",
"name": "Test ASR Model",
"vendor": "TestVendor",
"language": "zh",
"base_url": "https://api.test.com/v1",
"api_key": "test-api-key",
"model_name": "paraformer-v2",
"hotwords": ["测试", "语音"],
"enable_punctuation": True,
"enable_normalization": True,
"enabled": True
}

api/tests/test_asr.py Normal file

@@ -0,0 +1,289 @@
"""Tests for ASR Model API endpoints"""
import pytest
from unittest.mock import patch, MagicMock
class TestASRModelAPI:
"""Test cases for ASR Model endpoints"""
def test_get_asr_models_empty(self, client):
"""Test getting ASR models when database is empty"""
response = client.get("/api/asr")
assert response.status_code == 200
data = response.json()
assert "total" in data
assert "list" in data
assert data["total"] == 0
def test_create_asr_model(self, client, sample_asr_model_data):
"""Test creating a new ASR model"""
response = client.post("/api/asr", json=sample_asr_model_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == sample_asr_model_data["name"]
assert data["vendor"] == sample_asr_model_data["vendor"]
assert data["language"] == sample_asr_model_data["language"]
assert "id" in data
def test_create_asr_model_minimal(self, client):
"""Test creating an ASR model with minimal required data"""
data = {
"name": "Minimal ASR",
"vendor": "Test",
"language": "zh",
"base_url": "https://api.test.com",
"api_key": "test-key"
}
response = client.post("/api/asr", json=data)
assert response.status_code == 200
assert response.json()["name"] == "Minimal ASR"
def test_get_asr_model_by_id(self, client, sample_asr_model_data):
"""Test getting a specific ASR model by ID"""
# Create first
create_response = client.post("/api/asr", json=sample_asr_model_data)
model_id = create_response.json()["id"]
# Get by ID
response = client.get(f"/api/asr/{model_id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == model_id
assert data["name"] == sample_asr_model_data["name"]
def test_get_asr_model_not_found(self, client):
"""Test getting a non-existent ASR model"""
response = client.get("/api/asr/non-existent-id")
assert response.status_code == 404
def test_update_asr_model(self, client, sample_asr_model_data):
"""Test updating an ASR model"""
# Create first
create_response = client.post("/api/asr", json=sample_asr_model_data)
model_id = create_response.json()["id"]
# Update
update_data = {
"name": "Updated ASR Model",
"language": "en",
"enable_punctuation": False
}
response = client.put(f"/api/asr/{model_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == "Updated ASR Model"
assert data["language"] == "en"
assert data["enable_punctuation"] == False
def test_delete_asr_model(self, client, sample_asr_model_data):
"""Test deleting an ASR model"""
# Create first
create_response = client.post("/api/asr", json=sample_asr_model_data)
model_id = create_response.json()["id"]
# Delete
response = client.delete(f"/api/asr/{model_id}")
assert response.status_code == 200
# Verify deleted
get_response = client.get(f"/api/asr/{model_id}")
assert get_response.status_code == 404
def test_list_asr_models_with_pagination(self, client, sample_asr_model_data):
"""Test listing ASR models with pagination"""
# Create multiple models
for i in range(3):
data = sample_asr_model_data.copy()
data["id"] = f"test-asr-{i}"
data["name"] = f"ASR Model {i}"
client.post("/api/asr", json=data)
# Test pagination
response = client.get("/api/asr?page=1&limit=2")
assert response.status_code == 200
data = response.json()
assert data["total"] == 3
assert len(data["list"]) == 2
def test_filter_asr_models_by_language(self, client, sample_asr_model_data):
"""Test filtering ASR models by language"""
# Create models with different languages
for lang in ["zh", "en", "Multi-lingual"]:
data = sample_asr_model_data.copy()
data["id"] = f"test-asr-{lang}"
data["name"] = f"ASR {lang}"
data["language"] = lang
client.post("/api/asr", json=data)
# Filter by language
response = client.get("/api/asr?language=zh")
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
for model in data["list"]:
assert model["language"] == "zh"
def test_filter_asr_models_by_enabled(self, client, sample_asr_model_data):
"""Test filtering ASR models by enabled status"""
# Create enabled and disabled models
data = sample_asr_model_data.copy()
data["id"] = "test-asr-enabled"
data["name"] = "Enabled ASR"
data["enabled"] = True
client.post("/api/asr", json=data)
data["id"] = "test-asr-disabled"
data["name"] = "Disabled ASR"
data["enabled"] = False
client.post("/api/asr", json=data)
# Filter by enabled
response = client.get("/api/asr?enabled=true")
assert response.status_code == 200
data = response.json()
for model in data["list"]:
assert model["enabled"] is True
def test_create_asr_model_with_hotwords(self, client):
"""Test creating an ASR model with hotwords"""
data = {
"id": "asr-hotwords",
"name": "ASR with Hotwords",
"vendor": "SiliconFlow",
"language": "zh",
"base_url": "https://api.siliconflow.cn/v1",
"api_key": "test-key",
"model_name": "paraformer-v2",
"hotwords": ["你好", "查询", "帮助"],
"enable_punctuation": True,
"enable_normalization": True
}
response = client.post("/api/asr", json=data)
assert response.status_code == 200
result = response.json()
assert result["hotwords"] == ["你好", "查询", "帮助"]
def test_create_asr_model_with_all_fields(self, client):
"""Test creating an ASR model with all fields"""
data = {
"id": "full-asr",
"name": "Full ASR Model",
"vendor": "SiliconFlow",
"language": "zh",
"base_url": "https://api.siliconflow.cn/v1",
"api_key": "sk-test",
"model_name": "paraformer-v2",
"hotwords": ["测试"],
"enable_punctuation": True,
"enable_normalization": True,
"enabled": True
}
response = client.post("/api/asr", json=data)
assert response.status_code == 200
result = response.json()
assert result["name"] == "Full ASR Model"
assert result["enable_punctuation"] is True
assert result["enable_normalization"] is True
def test_test_asr_model_siliconflow(self, client, sample_asr_model_data):
"""Test the connection-test endpoint for a SiliconFlow ASR model"""
# Create model first
sample_asr_model_data["vendor"] = "SiliconFlow"
create_response = client.post("/api/asr", json=sample_asr_model_data)
model_id = create_response.json()["id"]
# Mock the HTTP response
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"results": [{"transcript": "测试文本", "language": "zh"}]
}
mock_response.raise_for_status = MagicMock()
mock_client.get.return_value = mock_response
mock_client.__enter__ = MagicMock(return_value=mock_client)
mock_client.__exit__ = MagicMock(return_value=False)
with patch('app.routers.asr.httpx.Client', return_value=mock_client):
response = client.post(f"/api/asr/{model_id}/test")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
def test_test_asr_model_openai(self, client, sample_asr_model_data):
"""Test the connection-test endpoint for an OpenAI ASR model"""
# Create model with OpenAI vendor
sample_asr_model_data["vendor"] = "OpenAI"
sample_asr_model_data["id"] = "test-asr-openai"
create_response = client.post("/api/asr", json=sample_asr_model_data)
model_id = create_response.json()["id"]
# Mock the HTTP response
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"text": "Test transcript"}
mock_response.raise_for_status = MagicMock()
mock_client.get.return_value = mock_response
mock_client.__enter__ = MagicMock(return_value=mock_client)
mock_client.__exit__ = MagicMock(return_value=False)
with patch('app.routers.asr.httpx.Client', return_value=mock_client):
response = client.post(f"/api/asr/{model_id}/test")
assert response.status_code == 200
def test_test_asr_model_failure(self, client, sample_asr_model_data):
"""Test the connection-test endpoint when the upstream request fails"""
# Create model first
create_response = client.post("/api/asr", json=sample_asr_model_data)
model_id = create_response.json()["id"]
# Mock HTTP error
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 401
mock_response.text = "Unauthorized"
mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized"))
mock_client.get.return_value = mock_response
mock_client.__enter__ = MagicMock(return_value=mock_client)
mock_client.__exit__ = MagicMock(return_value=False)
with patch('app.routers.asr.httpx.Client', return_value=mock_client):
response = client.post(f"/api/asr/{model_id}/test")
assert response.status_code == 200
data = response.json()
assert data["success"] is False
def test_different_asr_languages(self, client):
"""Test creating ASR models with different languages"""
for lang in ["zh", "en", "Multi-lingual"]:
data = {
"id": f"asr-lang-{lang}",
"name": f"ASR {lang}",
"vendor": "SiliconFlow",
"language": lang,
"base_url": "https://api.siliconflow.cn/v1",
"api_key": "test-key"
}
response = client.post("/api/asr", json=data)
assert response.status_code == 200
assert response.json()["language"] == lang
def test_different_asr_vendors(self, client):
"""Test creating ASR models with different vendors"""
vendors = ["SiliconFlow", "OpenAI", "Azure"]
for vendor in vendors:
data = {
"id": f"asr-vendor-{vendor.lower()}",
"name": f"ASR {vendor}",
"vendor": vendor,
"language": "zh",
"base_url": f"https://api.{vendor.lower()}.com/v1",
"api_key": "test-key"
}
response = client.post("/api/asr", json=data)
assert response.status_code == 200
assert response.json()["vendor"] == vendor


@@ -0,0 +1,168 @@
"""Tests for Assistant API endpoints"""
import pytest
class TestAssistantAPI:
"""Test cases for Assistant endpoints"""
def test_get_assistants_empty(self, client):
"""Test getting assistants when database is empty"""
response = client.get("/api/assistants")
assert response.status_code == 200
data = response.json()
assert "total" in data
assert "list" in data
def test_create_assistant(self, client, sample_assistant_data):
"""Test creating a new assistant"""
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == sample_assistant_data["name"]
assert data["opener"] == sample_assistant_data["opener"]
assert data["prompt"] == sample_assistant_data["prompt"]
assert data["language"] == sample_assistant_data["language"]
assert "id" in data
assert data["callCount"] == 0
def test_create_assistant_minimal(self, client):
"""Test creating an assistant with minimal required data"""
data = {"name": "Minimal Assistant"}
response = client.post("/api/assistants", json=data)
assert response.status_code == 200
assert response.json()["name"] == "Minimal Assistant"
def test_get_assistant_by_id(self, client, sample_assistant_data):
"""Test getting a specific assistant by ID"""
# Create first
create_response = client.post("/api/assistants", json=sample_assistant_data)
assistant_id = create_response.json()["id"]
# Get by ID
response = client.get(f"/api/assistants/{assistant_id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["name"] == sample_assistant_data["name"]
def test_get_assistant_not_found(self, client):
"""Test getting a non-existent assistant"""
response = client.get("/api/assistants/non-existent-id")
assert response.status_code == 404
def test_update_assistant(self, client, sample_assistant_data):
"""Test updating an assistant"""
# Create first
create_response = client.post("/api/assistants", json=sample_assistant_data)
assistant_id = create_response.json()["id"]
# Update
update_data = {
"name": "Updated Assistant",
"prompt": "You are an updated assistant.",
"speed": 1.5
}
response = client.put(f"/api/assistants/{assistant_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == "Updated Assistant"
assert data["prompt"] == "You are an updated assistant."
assert data["speed"] == 1.5
def test_delete_assistant(self, client, sample_assistant_data):
"""Test deleting an assistant"""
# Create first
create_response = client.post("/api/assistants", json=sample_assistant_data)
assistant_id = create_response.json()["id"]
# Delete
response = client.delete(f"/api/assistants/{assistant_id}")
assert response.status_code == 200
# Verify deleted
get_response = client.get(f"/api/assistants/{assistant_id}")
assert get_response.status_code == 404
def test_list_assistants_with_pagination(self, client, sample_assistant_data):
"""Test listing assistants with pagination"""
# Create multiple assistants
for i in range(3):
data = sample_assistant_data.copy()
data["name"] = f"Assistant {i}"
client.post("/api/assistants", json=data)
# Test pagination
response = client.get("/api/assistants?page=1&limit=2")
assert response.status_code == 200
data = response.json()
assert data["total"] == 3
assert len(data["list"]) == 2
def test_create_assistant_with_voice(self, client, sample_assistant_data, sample_voice_data):
"""Test creating an assistant with a voice reference"""
# Create a voice first
voice_response = client.post("/api/voices", json=sample_voice_data)
voice_id = voice_response.json()["id"]
# Create assistant with voice
sample_assistant_data["voice"] = voice_id
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
assert response.json()["voice"] == voice_id
def test_create_assistant_with_knowledge_base(self, client, sample_assistant_data):
"""Test creating an assistant with knowledge base reference"""
# Note: This test assumes knowledge base doesn't exist
sample_assistant_data["knowledgeBaseId"] = "non-existent-kb"
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
assert response.json()["knowledgeBaseId"] == "non-existent-kb"
def test_assistant_with_model_references(self, client, sample_assistant_data):
"""Test creating assistant with model references"""
sample_assistant_data.update({
"llmModelId": "llm-001",
"asrModelId": "asr-001",
"embeddingModelId": "emb-001",
"rerankModelId": "rerank-001"
})
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
data = response.json()
assert data["llmModelId"] == "llm-001"
assert data["asrModelId"] == "asr-001"
assert data["embeddingModelId"] == "emb-001"
assert data["rerankModelId"] == "rerank-001"
def test_assistant_with_tools(self, client, sample_assistant_data):
"""Test creating assistant with tools"""
sample_assistant_data["tools"] = ["weather", "calculator", "search"]
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
assert response.json()["tools"] == ["weather", "calculator", "search"]
def test_assistant_with_hotwords(self, client, sample_assistant_data):
"""Test creating assistant with hotwords"""
sample_assistant_data["hotwords"] = ["hello", "help", "stop"]
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
assert response.json()["hotwords"] == ["hello", "help", "stop"]
def test_different_config_modes(self, client, sample_assistant_data):
"""Test creating assistants with different config modes"""
for mode in ["platform", "dify", "fastgpt", "none"]:
sample_assistant_data["name"] = f"Assistant {mode}"
sample_assistant_data["configMode"] = mode
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
assert response.json()["configMode"] == mode
def test_different_languages(self, client, sample_assistant_data):
"""Test creating assistants with different languages"""
for lang in ["zh", "en", "ja", "ko"]:
sample_assistant_data["name"] = f"Assistant {lang}"
sample_assistant_data["language"] = lang
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
assert response.json()["language"] == lang
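The pagination assertions above (`total == 3` with two items in `list` for `page=1&limit=2`) imply the usual offset/limit scheme; a minimal sketch, assuming the routers compute `offset = (page - 1) * limit`:

```python
def paginate(items, page=1, limit=10):
    # "total" counts every row; "list" holds only the requested page.
    offset = (page - 1) * limit
    return {"total": len(items), "list": items[offset:offset + limit]}

# Three rows, page size two: total stays 3, the first page holds 2 items.
first = paginate(["a", "b", "c"], page=1, limit=2)
# The second page then holds the single remaining item.
second = paginate(["a", "b", "c"], page=2, limit=2)
```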

api/tests/test_history.py (new file)

@@ -0,0 +1,236 @@
"""Tests for History/Call Record API endpoints"""
import pytest
class TestHistoryAPI:
"""Test cases for History/Call Record endpoints"""
def test_get_history_empty(self, client):
"""Test getting history when database is empty"""
response = client.get("/api/history")
assert response.status_code == 200
data = response.json()
assert "total" in data
assert "list" in data
def test_create_call_record(self, client, sample_call_record_data):
"""Test creating a new call record"""
response = client.post("/api/history", json=sample_call_record_data)
assert response.status_code == 200
data = response.json()
assert data["user_id"] == sample_call_record_data["user_id"]
assert data["source"] == sample_call_record_data["source"]
assert data["status"] == "connected"
assert "id" in data
assert "started_at" in data
def test_create_call_record_with_assistant(self, client, sample_assistant_data, sample_call_record_data):
"""Test creating a call record associated with an assistant"""
# Create assistant first
assistant_response = client.post("/api/assistants", json=sample_assistant_data)
assistant_id = assistant_response.json()["id"]
# Create call record with assistant
sample_call_record_data["assistant_id"] = assistant_id
response = client.post("/api/history", json=sample_call_record_data)
assert response.status_code == 200
assert response.json()["assistant_id"] == assistant_id
def test_get_call_record_by_id(self, client, sample_call_record_data):
"""Test getting a specific call record by ID"""
# Create first
create_response = client.post("/api/history", json=sample_call_record_data)
record_id = create_response.json()["id"]
# Get by ID
response = client.get(f"/api/history/{record_id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == record_id
def test_get_call_record_not_found(self, client):
"""Test getting a non-existent call record"""
response = client.get("/api/history/non-existent-id")
assert response.status_code == 404
def test_update_call_record(self, client, sample_call_record_data):
"""Test updating a call record"""
# Create first
create_response = client.post("/api/history", json=sample_call_record_data)
record_id = create_response.json()["id"]
# Update
update_data = {
"status": "completed",
"summary": "Test summary",
"duration_seconds": 120
}
response = client.put(f"/api/history/{record_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["status"] == "completed"
assert data["summary"] == "Test summary"
assert data["duration_seconds"] == 120
def test_delete_call_record(self, client, sample_call_record_data):
"""Test deleting a call record"""
# Create first
create_response = client.post("/api/history", json=sample_call_record_data)
record_id = create_response.json()["id"]
# Delete
response = client.delete(f"/api/history/{record_id}")
assert response.status_code == 200
# Verify deleted
get_response = client.get(f"/api/history/{record_id}")
assert get_response.status_code == 404
def test_add_transcript(self, client, sample_call_record_data):
"""Test adding a transcript to a call record"""
# Create call record first
create_response = client.post("/api/history", json=sample_call_record_data)
record_id = create_response.json()["id"]
# Add transcript
transcript_data = {
"turn_index": 0,
"speaker": "human",
"content": "Hello, I need help",
"start_ms": 0,
"end_ms": 3000,
"confidence": 0.95
}
response = client.post(
f"/api/history/{record_id}/transcripts",
json=transcript_data
)
assert response.status_code == 200
data = response.json()
assert data["turn_index"] == 0
assert data["speaker"] == "human"
assert data["content"] == "Hello, I need help"
def test_add_multiple_transcripts(self, client, sample_call_record_data):
"""Test adding multiple transcripts"""
# Create call record first
create_response = client.post("/api/history", json=sample_call_record_data)
record_id = create_response.json()["id"]
# Add human transcript
human_transcript = {
"turn_index": 0,
"speaker": "human",
"content": "Hello",
"start_ms": 0,
"end_ms": 1000
}
client.post(f"/api/history/{record_id}/transcripts", json=human_transcript)
# Add AI transcript
ai_transcript = {
"turn_index": 1,
"speaker": "ai",
"content": "Hello! How can I help you?",
"start_ms": 1500,
"end_ms": 4000
}
client.post(f"/api/history/{record_id}/transcripts", json=ai_transcript)
# Verify both transcripts exist
response = client.get(f"/api/history/{record_id}")
assert response.status_code == 200
data = response.json()
assert len(data["transcripts"]) == 2
def test_filter_history_by_status(self, client, sample_call_record_data):
"""Test filtering history by status"""
# Create records with different statuses
for i in range(2):
data = sample_call_record_data.copy()
data["status"] = "connected" if i % 2 == 0 else "missed"
client.post("/api/history", json=data)
# Filter by status
response = client.get("/api/history?status=connected")
assert response.status_code == 200
data = response.json()
for record in data["list"]:
assert record["status"] == "connected"
def test_filter_history_by_source(self, client, sample_call_record_data):
"""Test filtering history by source"""
sample_call_record_data["source"] = "external"
client.post("/api/history", json=sample_call_record_data)
response = client.get("/api/history?source=external")
assert response.status_code == 200
data = response.json()
for record in data["list"]:
assert record["source"] == "external"
def test_history_pagination(self, client, sample_call_record_data):
"""Test history pagination"""
# Create multiple records
for i in range(5):
data = sample_call_record_data.copy()
data["source"] = f"source-{i}"
client.post("/api/history", json=data)
# Test pagination
response = client.get("/api/history?page=1&limit=3")
assert response.status_code == 200
data = response.json()
assert data["total"] == 5
assert len(data["list"]) == 3
def test_transcript_with_emotion(self, client, sample_call_record_data):
"""Test adding transcript with emotion"""
# Create call record first
create_response = client.post("/api/history", json=sample_call_record_data)
record_id = create_response.json()["id"]
# Add transcript with emotion
transcript_data = {
"turn_index": 0,
"speaker": "ai",
"content": "Great news!",
"start_ms": 0,
"end_ms": 2000,
"emotion": "happy"
}
response = client.post(
f"/api/history/{record_id}/transcripts",
json=transcript_data
)
assert response.status_code == 200
assert response.json()["emotion"] == "happy"
def test_history_with_cost(self, client, sample_call_record_data):
"""Test creating history with cost"""
sample_call_record_data["cost"] = 0.05
response = client.post("/api/history", json=sample_call_record_data)
assert response.status_code == 200
assert response.json()["cost"] == 0.05
def test_history_search(self, client, sample_call_record_data):
"""Test searching history"""
# Create record
create_response = client.post("/api/history", json=sample_call_record_data)
record_id = create_response.json()["id"]
# Add transcript with searchable content
transcript_data = {
"turn_index": 0,
"speaker": "human",
"content": "I want to buy a product",
"start_ms": 0,
"end_ms": 3000
}
client.post(f"/api/history/{record_id}/transcripts", json=transcript_data)
# Search endpoint may not exist yet, so tolerate 404
response = client.get("/api/history/search?q=product")
assert response.status_code in [200, 404]

api/tests/test_knowledge.py (new file)

@@ -0,0 +1,255 @@
"""Tests for Knowledge Base API endpoints"""
import pytest
class TestKnowledgeAPI:
"""Test cases for Knowledge Base endpoints"""
def test_get_knowledge_bases_empty(self, client):
"""Test getting knowledge bases when database is empty"""
response = client.get("/api/knowledge/bases")
assert response.status_code == 200
data = response.json()
assert "total" in data
assert "list" in data
def test_create_knowledge_base(self, client):
"""Test creating a new knowledge base"""
data = {
"name": "Test Knowledge Base",
"description": "A test knowledge base",
"embeddingModel": "text-embedding-3-small",
"chunkSize": 500,
"chunkOverlap": 50
}
response = client.post("/api/knowledge/bases", json=data)
assert response.status_code == 200
data = response.json()
assert data["name"] == "Test Knowledge Base"
assert data["description"] == "A test knowledge base"
assert data["embeddingModel"] == "text-embedding-3-small"
assert "id" in data
assert data["docCount"] == 0
assert data["chunkCount"] == 0
assert data["status"] == "active"
def test_create_knowledge_base_minimal(self, client):
"""Test creating a knowledge base with minimal data"""
data = {"name": "Minimal KB"}
response = client.post("/api/knowledge/bases", json=data)
assert response.status_code == 200
assert response.json()["name"] == "Minimal KB"
def test_get_knowledge_base_by_id(self, client):
"""Test getting a specific knowledge base by ID"""
# Create first
create_data = {"name": "Test KB"}
create_response = client.post("/api/knowledge/bases", json=create_data)
kb_id = create_response.json()["id"]
# Get by ID
response = client.get(f"/api/knowledge/bases/{kb_id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == kb_id
assert data["name"] == "Test KB"
def test_get_knowledge_base_not_found(self, client):
"""Test getting a non-existent knowledge base"""
response = client.get("/api/knowledge/bases/non-existent-id")
assert response.status_code == 404
def test_update_knowledge_base(self, client):
"""Test updating a knowledge base"""
# Create first
create_data = {"name": "Original Name"}
create_response = client.post("/api/knowledge/bases", json=create_data)
kb_id = create_response.json()["id"]
# Update
update_data = {
"name": "Updated Name",
"description": "Updated description",
"chunkSize": 800
}
response = client.put(f"/api/knowledge/bases/{kb_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == "Updated Name"
assert data["description"] == "Updated description"
assert data["chunkSize"] == 800
def test_delete_knowledge_base(self, client):
"""Test deleting a knowledge base"""
# Create first
create_data = {"name": "To Delete"}
create_response = client.post("/api/knowledge/bases", json=create_data)
kb_id = create_response.json()["id"]
# Delete
response = client.delete(f"/api/knowledge/bases/{kb_id}")
assert response.status_code == 200
# Verify deleted
get_response = client.get(f"/api/knowledge/bases/{kb_id}")
assert get_response.status_code == 404
def test_upload_document(self, client):
"""Test uploading a document to knowledge base"""
# Create KB first
create_data = {"name": "Test KB for Docs"}
create_response = client.post("/api/knowledge/bases", json=create_data)
kb_id = create_response.json()["id"]
# Upload document
doc_data = {
"name": "test-document.txt",
"size": "1024",
"fileType": "txt",
"storageUrl": "https://storage.example.com/test-document.txt"
}
response = client.post(
f"/api/knowledge/bases/{kb_id}/documents",
json=doc_data
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "test-document.txt"
assert "id" in data
assert data["status"] == "pending"
def test_delete_document(self, client):
"""Test deleting a document from knowledge base"""
# Create KB first
create_data = {"name": "Test KB for Delete"}
create_response = client.post("/api/knowledge/bases", json=create_data)
kb_id = create_response.json()["id"]
# Upload document
doc_data = {"name": "to-delete.txt", "size": "100", "fileType": "txt"}
upload_response = client.post(
f"/api/knowledge/bases/{kb_id}/documents",
json=doc_data
)
doc_id = upload_response.json()["id"]
# Delete document
response = client.delete(
f"/api/knowledge/bases/{kb_id}/documents/{doc_id}"
)
assert response.status_code == 200
def test_index_document(self, client):
"""Test indexing a document"""
# Create KB first
create_data = {"name": "Test KB for Index"}
create_response = client.post("/api/knowledge/bases", json=create_data)
kb_id = create_response.json()["id"]
# Index document
index_data = {
"document_id": "doc-001",
"content": "This is the content to index. It contains important information about the product."
}
response = client.post(
f"/api/knowledge/bases/{kb_id}/documents/doc-001/index",
json=index_data
)
# This might return 200 or error depending on vector store implementation
assert response.status_code in [200, 500]
def test_search_knowledge(self, client):
"""Test searching knowledge base"""
# Create KB first
create_data = {"name": "Test KB for Search"}
create_response = client.post("/api/knowledge/bases", json=create_data)
kb_id = create_response.json()["id"]
# Search (this may fail without indexed content)
search_data = {
"query": "test query",
"kb_id": kb_id,
"nResults": 5
}
response = client.post("/api/knowledge/search", json=search_data)
# This might return 200 or error depending on implementation
assert response.status_code in [200, 500]
def test_get_knowledge_stats(self, client):
"""Test getting knowledge base statistics"""
# Create KB first
create_data = {"name": "Test KB for Stats"}
create_response = client.post("/api/knowledge/bases", json=create_data)
kb_id = create_response.json()["id"]
response = client.get(f"/api/knowledge/bases/{kb_id}/stats")
assert response.status_code == 200
data = response.json()
assert data["kb_id"] == kb_id
assert "docCount" in data
assert "chunkCount" in data
def test_knowledge_bases_pagination(self, client):
"""Test knowledge bases pagination"""
# Create multiple KBs
for i in range(5):
data = {"name": f"Knowledge Base {i}"}
client.post("/api/knowledge/bases", json=data)
# Test pagination
response = client.get("/api/knowledge/bases?page=1&limit=3")
assert response.status_code == 200
data = response.json()
assert data["total"] == 5
assert len(data["list"]) == 3
def test_different_embedding_models(self, client):
"""Test creating KB with different embedding models"""
models = [
"text-embedding-3-small",
"text-embedding-3-large",
"bge-small-zh"
]
for model in models:
data = {"name": f"KB with {model}", "embeddingModel": model}
response = client.post("/api/knowledge/bases", json=data)
assert response.status_code == 200
assert response.json()["embeddingModel"] == model
def test_different_chunk_sizes(self, client):
"""Test creating KB with different chunk configurations"""
configs = [
{"chunkSize": 500, "chunkOverlap": 50},
{"chunkSize": 1000, "chunkOverlap": 100},
{"chunkSize": 256, "chunkOverlap": 25}
]
for config in configs:
data = {"name": "Chunk Test KB", **config}
response = client.post("/api/knowledge/bases", json=data)
assert response.status_code == 200
def test_knowledge_base_with_documents(self, client):
"""Test creating KB and adding multiple documents"""
# Create KB
create_data = {"name": "KB with Multiple Docs"}
create_response = client.post("/api/knowledge/bases", json=create_data)
kb_id = create_response.json()["id"]
# Add multiple documents
for i in range(3):
doc_data = {
"name": f"document-{i}.txt",
"size": f"{1000 + i * 100}",
"fileType": "txt"
}
client.post(
f"/api/knowledge/bases/{kb_id}/documents",
json=doc_data
)
# Verify documents are listed
response = client.get(f"/api/knowledge/bases/{kb_id}")
assert response.status_code == 200
data = response.json()
assert len(data["documents"]) == 3
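The `chunkSize`/`chunkOverlap` pairs exercised above correspond to the standard sliding-window text splitter; a sketch under the assumption of character-based chunking (the backend's actual splitter is not part of this diff):

```python
def split_text(text, chunk_size=500, chunk_overlap=50):
    # Each chunk starts chunk_size - chunk_overlap characters after the
    # previous one, so consecutive chunks share chunk_overlap characters.
    step = chunk_size - chunk_overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]

chunks = split_text("x" * 1000, chunk_size=500, chunk_overlap=50)
# Windows start at 0, 450, 900 -> three chunks for 1000 characters.
```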

api/tests/test_llm.py (new file)

@@ -0,0 +1,246 @@
"""Tests for LLM Model API endpoints"""
import pytest
from unittest.mock import patch, MagicMock
class TestLLMModelAPI:
"""Test cases for LLM Model endpoints"""
def test_get_llm_models_empty(self, client):
"""Test getting LLM models when database is empty"""
response = client.get("/api/llm")
assert response.status_code == 200
data = response.json()
assert "total" in data
assert "list" in data
assert data["total"] == 0
def test_create_llm_model(self, client, sample_llm_model_data):
"""Test creating a new LLM model"""
response = client.post("/api/llm", json=sample_llm_model_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == sample_llm_model_data["name"]
assert data["vendor"] == sample_llm_model_data["vendor"]
assert data["type"] == sample_llm_model_data["type"]
assert data["base_url"] == sample_llm_model_data["base_url"]
assert "id" in data
def test_create_llm_model_minimal(self, client):
"""Test creating an LLM model with minimal required data"""
data = {
"name": "Minimal LLM",
"vendor": "Test",
"type": "text",
"base_url": "https://api.test.com",
"api_key": "test-key"
}
response = client.post("/api/llm", json=data)
assert response.status_code == 200
assert response.json()["name"] == "Minimal LLM"
def test_get_llm_model_by_id(self, client, sample_llm_model_data):
"""Test getting a specific LLM model by ID"""
# Create first
create_response = client.post("/api/llm", json=sample_llm_model_data)
model_id = create_response.json()["id"]
# Get by ID
response = client.get(f"/api/llm/{model_id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == model_id
assert data["name"] == sample_llm_model_data["name"]
def test_get_llm_model_not_found(self, client):
"""Test getting a non-existent LLM model"""
response = client.get("/api/llm/non-existent-id")
assert response.status_code == 404
def test_update_llm_model(self, client, sample_llm_model_data):
"""Test updating an LLM model"""
# Create first
create_response = client.post("/api/llm", json=sample_llm_model_data)
model_id = create_response.json()["id"]
# Update
update_data = {
"name": "Updated LLM Model",
"temperature": 0.5,
"context_length": 8192
}
response = client.put(f"/api/llm/{model_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == "Updated LLM Model"
assert data["temperature"] == 0.5
assert data["context_length"] == 8192
def test_delete_llm_model(self, client, sample_llm_model_data):
"""Test deleting an LLM model"""
# Create first
create_response = client.post("/api/llm", json=sample_llm_model_data)
model_id = create_response.json()["id"]
# Delete
response = client.delete(f"/api/llm/{model_id}")
assert response.status_code == 200
# Verify deleted
get_response = client.get(f"/api/llm/{model_id}")
assert get_response.status_code == 404
def test_list_llm_models_with_pagination(self, client, sample_llm_model_data):
"""Test listing LLM models with pagination"""
# Create multiple models
for i in range(3):
data = sample_llm_model_data.copy()
data["id"] = f"test-llm-{i}"
data["name"] = f"LLM Model {i}"
client.post("/api/llm", json=data)
# Test pagination
response = client.get("/api/llm?page=1&limit=2")
assert response.status_code == 200
data = response.json()
assert data["total"] == 3
assert len(data["list"]) == 2
def test_filter_llm_models_by_type(self, client, sample_llm_model_data):
"""Test filtering LLM models by type"""
# Create models with different types
for model_type in ["text", "embedding", "rerank"]:
data = sample_llm_model_data.copy()
data["id"] = f"test-llm-{model_type}"
data["name"] = f"LLM {model_type}"
data["type"] = model_type
client.post("/api/llm", json=data)
# Filter by type
response = client.get("/api/llm?model_type=text")
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
for model in data["list"]:
assert model["type"] == "text"
def test_filter_llm_models_by_enabled(self, client, sample_llm_model_data):
"""Test filtering LLM models by enabled status"""
# Create enabled and disabled models
data = sample_llm_model_data.copy()
data["id"] = "test-llm-enabled"
data["name"] = "Enabled LLM"
data["enabled"] = True
client.post("/api/llm", json=data)
data["id"] = "test-llm-disabled"
data["name"] = "Disabled LLM"
data["enabled"] = False
client.post("/api/llm", json=data)
# Filter by enabled
response = client.get("/api/llm?enabled=true")
assert response.status_code == 200
data = response.json()
for model in data["list"]:
assert model["enabled"] == True
def test_create_llm_model_with_all_fields(self, client):
"""Test creating an LLM model with all fields"""
data = {
"id": "full-llm",
"name": "Full LLM Model",
"vendor": "OpenAI",
"type": "text",
"base_url": "https://api.openai.com/v1",
"api_key": "sk-test",
"model_name": "gpt-4",
"temperature": 0.8,
"context_length": 16384,
"enabled": True
}
response = client.post("/api/llm", json=data)
assert response.status_code == 200
result = response.json()
assert result["name"] == "Full LLM Model"
assert result["temperature"] == 0.8
assert result["context_length"] == 16384
@patch('httpx.Client')
def test_test_llm_model_success(self, mock_client_class, client, sample_llm_model_data):
"""Test testing an LLM model with successful connection"""
# Create model first
create_response = client.post("/api/llm", json=sample_llm_model_data)
model_id = create_response.json()["id"]
# Mock the HTTP response
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "OK"}}]
}
mock_response.raise_for_status = MagicMock()
mock_client.post.return_value = mock_response
mock_client.__enter__ = MagicMock(return_value=mock_client)
mock_client.__exit__ = MagicMock(return_value=False)
with patch('app.routers.llm.httpx.Client', return_value=mock_client):
response = client.post(f"/api/llm/{model_id}/test")
assert response.status_code == 200
data = response.json()
assert data["success"] == True
@patch('httpx.Client')
def test_test_llm_model_failure(self, mock_client_class, client, sample_llm_model_data):
"""Test testing an LLM model with failed connection"""
# Create model first
create_response = client.post("/api/llm", json=sample_llm_model_data)
model_id = create_response.json()["id"]
# Mock HTTP error
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 401
mock_response.text = "Unauthorized"
mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized"))
mock_client.post.return_value = mock_response
mock_client.__enter__ = MagicMock(return_value=mock_client)
mock_client.__exit__ = MagicMock(return_value=False)
with patch('app.routers.llm.httpx.Client', return_value=mock_client):
response = client.post(f"/api/llm/{model_id}/test")
assert response.status_code == 200
data = response.json()
assert data["success"] == False
def test_different_llm_vendors(self, client):
"""Test creating LLM models with different vendors"""
vendors = ["OpenAI", "SiliconFlow", "ZhipuAI", "Anthropic"]
for vendor in vendors:
data = {
"id": f"test-{vendor.lower()}",
"name": f"Test {vendor}",
"vendor": vendor,
"type": "text",
"base_url": f"https://api.{vendor.lower()}.com/v1",
"api_key": "test-key"
}
response = client.post("/api/llm", json=data)
assert response.status_code == 200
assert response.json()["vendor"] == vendor
def test_embedding_llm_model(self, client):
"""Test creating an embedding LLM model"""
data = {
"id": "embedding-test",
"name": "Embedding Model",
"vendor": "OpenAI",
"type": "embedding",
"base_url": "https://api.openai.com/v1",
"api_key": "test-key",
"model_name": "text-embedding-3-small"
}
response = client.post("/api/llm", json=data)
assert response.status_code == 200
assert response.json()["type"] == "embedding"

api/tests/test_tools.py Normal file

@@ -0,0 +1,267 @@
"""Tests for Tools & Autotest API endpoints"""
import pytest
from unittest.mock import patch, MagicMock
class TestToolsAPI:
"""Test cases for Tools endpoints"""
def test_list_available_tools(self, client):
"""Test listing all available tools"""
response = client.get("/api/tools/list")
assert response.status_code == 200
data = response.json()
assert "tools" in data
# Check for expected tools
tools = data["tools"]
assert "search" in tools
assert "calculator" in tools
assert "weather" in tools
def test_get_tool_detail(self, client):
"""Test getting a specific tool's details"""
response = client.get("/api/tools/list/search")
assert response.status_code == 200
data = response.json()
assert data["name"] == "网络搜索"
assert "parameters" in data
def test_get_tool_detail_not_found(self, client):
"""Test getting a non-existent tool"""
response = client.get("/api/tools/list/non-existent-tool")
assert response.status_code == 404
def test_health_check(self, client):
"""Test health check endpoint"""
response = client.get("/api/tools/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "timestamp" in data
assert "tools" in data
class TestAutotestAPI:
"""Test cases for Autotest endpoints"""
def test_autotest_no_models(self, client):
"""Test autotest without specifying model IDs"""
response = client.post("/api/tools/autotest")
assert response.status_code == 200
data = response.json()
assert "id" in data
assert "tests" in data
assert "summary" in data
# Should have test failures since no models provided
assert data["summary"]["total"] > 0
def test_autotest_with_llm_model(self, client, sample_llm_model_data):
"""Test autotest with an LLM model"""
# Create an LLM model first
create_response = client.post("/api/llm", json=sample_llm_model_data)
model_id = create_response.json()["id"]
# Run autotest
response = client.post(f"/api/tools/autotest?llm_model_id={model_id}&test_asr=false")
assert response.status_code == 200
data = response.json()
assert "tests" in data
assert "summary" in data
def test_autotest_with_asr_model(self, client, sample_asr_model_data):
"""Test autotest with an ASR model"""
# Create an ASR model first
create_response = client.post("/api/asr", json=sample_asr_model_data)
model_id = create_response.json()["id"]
# Run autotest
response = client.post(f"/api/tools/autotest?asr_model_id={model_id}&test_llm=false")
assert response.status_code == 200
data = response.json()
assert "tests" in data
assert "summary" in data
def test_autotest_with_both_models(self, client, sample_llm_model_data, sample_asr_model_data):
"""Test autotest with both LLM and ASR models"""
# Create models
llm_response = client.post("/api/llm", json=sample_llm_model_data)
llm_id = llm_response.json()["id"]
asr_response = client.post("/api/asr", json=sample_asr_model_data)
asr_id = asr_response.json()["id"]
# Run autotest
response = client.post(
f"/api/tools/autotest?llm_model_id={llm_id}&asr_model_id={asr_id}"
)
assert response.status_code == 200
data = response.json()
assert "tests" in data
assert "summary" in data
@patch('httpx.Client')
def test_autotest_llm_model_success(self, mock_client_class, client, sample_llm_model_data):
"""Test autotest for a specific LLM model with successful connection"""
# Create an LLM model first
create_response = client.post("/api/llm", json=sample_llm_model_data)
model_id = create_response.json()["id"]
# Mock the HTTP response for successful connection
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "OK"}}]
}
mock_response.raise_for_status = MagicMock()
mock_response.iter_bytes = MagicMock(return_value=[b'chunk1', b'chunk2'])
mock_client.post.return_value = mock_response
mock_client.__enter__ = MagicMock(return_value=mock_client)
mock_client.__exit__ = MagicMock(return_value=False)
with patch('app.routers.tools.httpx.Client', return_value=mock_client):
response = client.post(f"/api/tools/autotest/llm/{model_id}")
assert response.status_code == 200
data = response.json()
assert "tests" in data
assert "summary" in data
@patch('httpx.Client')
def test_autotest_asr_model_success(self, mock_client_class, client, sample_asr_model_data):
"""Test autotest for a specific ASR model with successful connection"""
# Create an ASR model first
create_response = client.post("/api/asr", json=sample_asr_model_data)
model_id = create_response.json()["id"]
# Mock the HTTP response for successful connection
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
mock_client.get.return_value = mock_response
mock_client.__enter__ = MagicMock(return_value=mock_client)
mock_client.__exit__ = MagicMock(return_value=False)
with patch('app.routers.tools.httpx.Client', return_value=mock_client):
response = client.post(f"/api/tools/autotest/asr/{model_id}")
assert response.status_code == 200
data = response.json()
assert "tests" in data
assert "summary" in data
def test_autotest_llm_model_not_found(self, client):
"""Test autotest for a non-existent LLM model"""
response = client.post("/api/tools/autotest/llm/non-existent-id")
assert response.status_code == 200
data = response.json()
# Should have a failure test
assert any(not t["passed"] for t in data["tests"])
def test_autotest_asr_model_not_found(self, client):
"""Test autotest for a non-existent ASR model"""
response = client.post("/api/tools/autotest/asr/non-existent-id")
assert response.status_code == 200
data = response.json()
# Should have a failure test
assert any(not t["passed"] for t in data["tests"])
@patch('httpx.Client')
def test_test_message_success(self, mock_client_class, client, sample_llm_model_data):
"""Test sending a test message to an LLM model"""
# Create an LLM model first
create_response = client.post("/api/llm", json=sample_llm_model_data)
model_id = create_response.json()["id"]
# Mock the HTTP response
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "Hello! This is a test reply."}}],
"usage": {"total_tokens": 10}
}
mock_response.raise_for_status = MagicMock()
mock_client.post.return_value = mock_response
mock_client.__enter__ = MagicMock(return_value=mock_client)
mock_client.__exit__ = MagicMock(return_value=False)
with patch('app.routers.tools.httpx.Client', return_value=mock_client):
response = client.post(
f"/api/tools/test-message?llm_model_id={model_id}",
json={"message": "Hello!"}
)
assert response.status_code == 200
data = response.json()
assert data["success"] == True
assert "reply" in data
def test_test_message_model_not_found(self, client):
"""Test sending a test message to a non-existent model"""
response = client.post(
"/api/tools/test-message?llm_model_id=non-existent",
json={"message": "Hello!"}
)
assert response.status_code == 404
def test_autotest_result_structure(self, client):
"""Test that autotest results have the correct structure"""
response = client.post("/api/tools/autotest")
assert response.status_code == 200
data = response.json()
# Check required fields
assert "id" in data
assert "started_at" in data
assert "duration_ms" in data
assert "tests" in data
assert "summary" in data
# Check summary structure
assert "passed" in data["summary"]
assert "failed" in data["summary"]
assert "total" in data["summary"]
# Check test structure
if data["tests"]:
test = data["tests"][0]
assert "name" in test
assert "passed" in test
assert "message" in test
assert "duration_ms" in test
def test_tools_have_required_fields(self, client):
"""Test that all tools have required fields"""
response = client.get("/api/tools/list")
assert response.status_code == 200
data = response.json()
for tool_id, tool in data["tools"].items():
assert "name" in tool
assert "description" in tool
assert "parameters" in tool
# Check parameters structure
params = tool["parameters"]
assert "type" in params
assert "properties" in params
def test_calculator_tool_parameters(self, client):
"""Test calculator tool has correct parameters"""
response = client.get("/api/tools/list/calculator")
assert response.status_code == 200
data = response.json()
assert data["name"] == "计算器"
assert "expression" in data["parameters"]["properties"]
assert "required" in data["parameters"]
assert "expression" in data["parameters"]["required"]
def test_translate_tool_parameters(self, client):
"""Test translate tool has correct parameters"""
response = client.get("/api/tools/list/translate")
assert response.status_code == 200
data = response.json()
assert data["name"] == "翻译"
assert "text" in data["parameters"]["properties"]
assert "target_lang" in data["parameters"]["properties"]

api/tests/test_voices.py Normal file

@@ -0,0 +1,132 @@
"""Tests for Voice API endpoints"""
import pytest
class TestVoiceAPI:
"""Test cases for Voice endpoints"""
def test_get_voices_empty(self, client):
"""Test getting voices when database is empty"""
response = client.get("/api/voices")
assert response.status_code == 200
data = response.json()
assert "total" in data
assert "list" in data
def test_create_voice(self, client, sample_voice_data):
"""Test creating a new voice"""
response = client.post("/api/voices", json=sample_voice_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == sample_voice_data["name"]
assert data["vendor"] == sample_voice_data["vendor"]
assert data["gender"] == sample_voice_data["gender"]
assert data["language"] == sample_voice_data["language"]
assert "id" in data
def test_create_voice_minimal(self, client):
"""Test creating a voice with minimal data"""
data = {
"name": "Minimal Voice",
"vendor": "Test",
"gender": "Male",
"language": "en",
"description": ""
}
response = client.post("/api/voices", json=data)
assert response.status_code == 200
def test_get_voice_by_id(self, client, sample_voice_data):
"""Test getting a specific voice by ID"""
# Create first
create_response = client.post("/api/voices", json=sample_voice_data)
voice_id = create_response.json()["id"]
# Get by ID
response = client.get(f"/api/voices/{voice_id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == voice_id
assert data["name"] == sample_voice_data["name"]
def test_get_voice_not_found(self, client):
"""Test getting a non-existent voice"""
response = client.get("/api/voices/non-existent-id")
assert response.status_code == 404
def test_update_voice(self, client, sample_voice_data):
"""Test updating a voice"""
# Create first
create_response = client.post("/api/voices", json=sample_voice_data)
voice_id = create_response.json()["id"]
# Update
update_data = {"name": "Updated Voice", "speed": 1.5}
response = client.put(f"/api/voices/{voice_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == "Updated Voice"
assert data["speed"] == 1.5
def test_delete_voice(self, client, sample_voice_data):
"""Test deleting a voice"""
# Create first
create_response = client.post("/api/voices", json=sample_voice_data)
voice_id = create_response.json()["id"]
# Delete
response = client.delete(f"/api/voices/{voice_id}")
assert response.status_code == 200
# Verify deleted
get_response = client.get(f"/api/voices/{voice_id}")
assert get_response.status_code == 404
def test_list_voices_with_pagination(self, client, sample_voice_data):
"""Test listing voices with pagination"""
# Create multiple voices
for i in range(3):
data = sample_voice_data.copy()
data["name"] = f"Voice {i}"
client.post("/api/voices", json=data)
# Test pagination
response = client.get("/api/voices?page=1&limit=2")
assert response.status_code == 200
data = response.json()
assert data["total"] == 3
assert len(data["list"]) == 2
def test_filter_voices_by_vendor(self, client, sample_voice_data):
"""Test filtering voices by vendor"""
# Create voice with specific vendor
sample_voice_data["vendor"] = "FilterTestVendor"
client.post("/api/voices", json=sample_voice_data)
response = client.get("/api/voices?vendor=FilterTestVendor")
assert response.status_code == 200
data = response.json()
for voice in data["list"]:
assert voice["vendor"] == "FilterTestVendor"
def test_filter_voices_by_language(self, client, sample_voice_data):
"""Test filtering voices by language"""
sample_voice_data["language"] = "en"
client.post("/api/voices", json=sample_voice_data)
response = client.get("/api/voices?language=en")
assert response.status_code == 200
data = response.json()
for voice in data["list"]:
assert voice["language"] == "en"
def test_filter_voices_by_gender(self, client, sample_voice_data):
"""Test filtering voices by gender"""
sample_voice_data["gender"] = "Female"
client.post("/api/voices", json=sample_voice_data)
response = client.get("/api/voices?gender=Female")
assert response.status_code == 200
data = response.json()
for voice in data["list"]:
assert voice["gender"] == "Female"