Update backend schema
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
FROM python:3.11-slim
|
||||
FROM python:3.12-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -12,6 +12,6 @@ COPY . .
|
||||
# 创建数据目录
|
||||
RUN mkdir -p /app/data
|
||||
|
||||
EXPOSE 8000
|
||||
EXPOSE 8100
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8100", "--reload"]
|
||||
|
||||
273
api/README.md
273
api/README.md
@@ -1,13 +1,13 @@
|
||||
# AI VideoAssistant Backend
|
||||
|
||||
Python 后端 API,配合前端 `ai-videoassistant-frontend` 使用。
|
||||
Python 后端 API,配合前端 `web/` 模块使用。
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 安装依赖
|
||||
|
||||
```bash
|
||||
cd ~/Code/ai-videoassistant-backend
|
||||
cd api
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
@@ -25,44 +25,162 @@ python init_db.py
|
||||
|
||||
```bash
|
||||
# 开发模式 (热重载)
|
||||
python -m uvicorn main:app --reload --host 0.0.0.0 --port 8000
|
||||
python -m uvicorn main:app --reload --host 0.0.0.0 --port 8100
|
||||
```
|
||||
|
||||
服务运行在: http://localhost:8100
|
||||
|
||||
### 4. 测试 API
|
||||
|
||||
```bash
|
||||
# 健康检查
|
||||
curl http://localhost:8000/health
|
||||
curl http://localhost:8100/health
|
||||
|
||||
# 获取助手列表
|
||||
curl http://localhost:8000/api/assistants
|
||||
curl http://localhost:8100/api/assistants
|
||||
|
||||
# 获取声音列表
|
||||
curl http://localhost:8000/api/voices
|
||||
curl http://localhost:8100/api/voices
|
||||
|
||||
# 获取通话历史
|
||||
curl http://localhost:8000/api/history
|
||||
curl http://localhost:8100/api/history
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## API 文档
|
||||
|
||||
| 端点 | 方法 | 说明 |
|
||||
完整 API 文档位于 [docs/](docs/) 目录:
|
||||
|
||||
| 模块 | 端点 | 方法 | 说明 |
|
||||
|------|------|------|------|
|
||||
| **Assistant** | `/api/assistants` | GET | 助手列表 |
|
||||
| | | POST | 创建助手 |
|
||||
| | `/api/assistants/{id}` | GET | 助手详情 |
|
||||
| | | PUT | 更新助手 |
|
||||
| | | DELETE | 删除助手 |
|
||||
| **Voice** | `/api/voices` | GET | 声音库列表 |
|
||||
| | | POST | 添加声音 |
|
||||
| | `/api/voices/{id}` | GET | 声音详情 |
|
||||
| | | PUT | 更新声音 |
|
||||
| | | DELETE | 删除声音 |
|
||||
| | `/api/voices/{id}/preview` | POST | 预览声音 |
|
||||
| **LLM Models** | `/api/models/llm` | GET | LLM 模型列表 |
|
||||
| | | POST | 添加模型 |
|
||||
| | `/api/models/llm/{id}` | GET | 模型详情 |
|
||||
| | | PUT | 更新模型 |
|
||||
| | | DELETE | 删除模型 |
|
||||
| | `/api/models/llm/{id}/test` | POST | 测试模型连接 |
|
||||
| **ASR Models** | `/api/models/asr` | GET | ASR 模型列表 |
|
||||
| | | POST | 添加模型 |
|
||||
| | `/api/models/asr/{id}` | GET | 模型详情 |
|
||||
| | | PUT | 更新模型 |
|
||||
| | | DELETE | 删除模型 |
|
||||
| | `/api/models/asr/{id}/test` | POST | 测试识别 |
|
||||
| **History** | `/api/history` | GET | 通话历史列表 |
|
||||
| | `/api/history/{id}` | GET | 通话详情 |
|
||||
| | | PUT | 更新通话记录 |
|
||||
| | | DELETE | 删除记录 |
|
||||
| | `/api/history/{id}/transcripts` | POST | 添加转写 |
|
||||
| | `/api/history/search` | GET | 搜索历史 |
|
||||
| | `/api/history/stats` | GET | 统计数据 |
|
||||
| **Knowledge** | `/api/knowledge/bases` | GET | 知识库列表 |
|
||||
| | | POST | 创建知识库 |
|
||||
| | `/api/knowledge/bases/{id}` | GET | 知识库详情 |
|
||||
| | | PUT | 更新知识库 |
|
||||
| | | DELETE | 删除知识库 |
|
||||
| | `/api/knowledge/bases/{kb_id}/documents` | POST | 上传文档 |
|
||||
| | `/api/knowledge/bases/{kb_id}/documents/{doc_id}` | DELETE | 删除文档 |
|
||||
| | `/api/knowledge/bases/{kb_id}/documents/{doc_id}/index` | POST | 索引文档 |
|
||||
| | `/api/knowledge/search` | POST | 知识搜索 |
|
||||
| **Workflow** | `/api/workflows` | GET | 工作流列表 |
|
||||
| | | POST | 创建工作流 |
|
||||
| | `/api/workflows/{id}` | GET | 工作流详情 |
|
||||
| | | PUT | 更新工作流 |
|
||||
| | | DELETE | 删除工作流 |
|
||||
|
||||
---
|
||||
|
||||
## 数据模型
|
||||
|
||||
### Assistant (小助手)
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| `/api/assistants` | GET | 助手列表 |
|
||||
| `/api/assistants` | POST | 创建助手 |
|
||||
| `/api/assistants/{id}` | GET | 助手详情 |
|
||||
| `/api/assistants/{id}` | PUT | 更新助手 |
|
||||
| `/api/assistants/{id}` | DELETE | 删除助手 |
|
||||
| `/api/voices` | GET | 声音库列表 |
|
||||
| `/api/history` | GET | 通话历史列表 |
|
||||
| `/api/history/{id}` | GET | 通话详情 |
|
||||
| `/api/history/{id}/transcripts` | POST | 添加转写 |
|
||||
| `/api/history/{id}/audio/{turn}` | GET | 获取音频 |
|
||||
| id | string | 助手 ID |
|
||||
| name | string | 助手名称 |
|
||||
| opener | string | 开场白 |
|
||||
| prompt | string | 系统提示词 |
|
||||
| knowledgeBaseId | string | 关联知识库 ID |
|
||||
| language | string | 语言: zh/en |
|
||||
| voice | string | 声音 ID |
|
||||
| speed | float | 语速 (0.5-2.0) |
|
||||
| hotwords | array | 热词列表 |
|
||||
| tools | array | 启用的工具列表 |
|
||||
| llmModelId | string | LLM 模型 ID |
|
||||
| asrModelId | string | ASR 模型 ID |
|
||||
| embeddingModelId | string | Embedding 模型 ID |
|
||||
| rerankModelId | string | Rerank 模型 ID |
|
||||
|
||||
### Voice (声音资源)
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| id | string | 声音 ID |
|
||||
| name | string | 声音名称 |
|
||||
| vendor | string | 厂商: Ali/Volcano/Minimax |
|
||||
| gender | string | 性别: Male/Female |
|
||||
| language | string | 语言: zh/en |
|
||||
| model | string | 厂商模型标识 |
|
||||
| voice_key | string | 厂商 voice_key |
|
||||
| speed | float | 语速 |
|
||||
| gain | int | 增益 (dB) |
|
||||
| pitch | int | 音调 |
|
||||
|
||||
### LLMModel (模型接入)
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| id | string | 模型 ID |
|
||||
| name | string | 模型名称 |
|
||||
| vendor | string | 厂商 |
|
||||
| type | string | 类型: text/embedding/rerank |
|
||||
| base_url | string | API 地址 |
|
||||
| api_key | string | API 密钥 |
|
||||
| model_name | string | 模型名称 |
|
||||
| temperature | float | 温度参数 |
|
||||
|
||||
### ASRModel (语音识别)
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| id | string | 模型 ID |
|
||||
| name | string | 模型名称 |
|
||||
| vendor | string | 厂商 |
|
||||
| language | string | 语言: zh/en/Multi-lingual |
|
||||
| base_url | string | API 地址 |
|
||||
| api_key | string | API 密钥 |
|
||||
| hotwords | array | 热词列表 |
|
||||
|
||||
### CallRecord (通话记录)
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| id | string | 记录 ID |
|
||||
| assistant_id | string | 助手 ID |
|
||||
| source | string | 来源: debug/external |
|
||||
| status | string | 状态: connected/missed/failed |
|
||||
| started_at | string | 开始时间 |
|
||||
| duration_seconds | int | 通话时长 |
|
||||
| summary | string | 通话摘要 |
|
||||
| transcripts | array | 对话转写 |
|
||||
|
||||
---
|
||||
|
||||
## 使用 Docker 启动
|
||||
|
||||
```bash
|
||||
cd ~/Code/ai-videoassistant-backend
|
||||
cd api
|
||||
|
||||
# 启动所有服务
|
||||
docker-compose up -d
|
||||
@@ -71,33 +189,144 @@ docker-compose up -d
|
||||
docker-compose logs -f backend
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
backend/
|
||||
api/
|
||||
├── app/
|
||||
│ ├── __init__.py
|
||||
│ ├── main.py # FastAPI 入口
|
||||
│ ├── db.py # SQLite 连接
|
||||
│ ├── models.py # 数据模型
|
||||
│ ├── models.py # SQLAlchemy 数据模型
|
||||
│ ├── schemas.py # Pydantic 模型
|
||||
│ ├── storage.py # MinIO 存储
|
||||
│ ├── vector_store.py # 向量存储
|
||||
│ └── routers/
|
||||
│ ├── __init__.py
|
||||
│ ├── assistants.py # 助手 API
|
||||
│ └── history.py # 通话记录 API
|
||||
│ ├── history.py # 通话记录 API
|
||||
│ └── knowledge.py # 知识库 API
|
||||
├── data/ # 数据库文件
|
||||
├── docs/ # API 文档
|
||||
├── requirements.txt
|
||||
├── .env
|
||||
├── init_db.py
|
||||
├── main.py
|
||||
└── docker-compose.yml
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 环境变量
|
||||
|
||||
| 变量 | 默认值 | 说明 |
|
||||
|------|--------|------|
|
||||
| `PORT` | `8100` | 服务端口 |
|
||||
| `DATABASE_URL` | `sqlite:///./data/app.db` | 数据库连接 |
|
||||
| `MINIO_ENDPOINT` | `localhost:9000` | MinIO 地址 |
|
||||
| `MINIO_ACCESS_KEY` | `admin` | MinIO 密钥 |
|
||||
| `MINIO_SECRET_KEY` | `password123` | MinIO 密码 |
|
||||
| `MINIO_BUCKET` | `ai-audio` | 存储桶名称 |
|
||||
|
||||
---
|
||||
|
||||
## 数据库迁移
|
||||
|
||||
开发环境重新创建数据库:
|
||||
|
||||
```bash
|
||||
rm -f api/data/app.db
|
||||
python api/init_db.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 测试
|
||||
|
||||
### 安装测试依赖
|
||||
|
||||
```bash
|
||||
cd api
|
||||
pip install pytest pytest-cov -q
|
||||
```
|
||||
|
||||
### 运行所有测试
|
||||
|
||||
```bash
|
||||
# Windows
|
||||
run_tests.bat
|
||||
|
||||
# 或使用 pytest
|
||||
pytest tests/ -v
|
||||
```
|
||||
|
||||
### 运行特定测试
|
||||
|
||||
```bash
|
||||
# 只测试声音 API
|
||||
pytest tests/test_voices.py -v
|
||||
|
||||
# 只测试助手 API
|
||||
pytest tests/test_assistants.py -v
|
||||
|
||||
# 只测试历史记录 API
|
||||
pytest tests/test_history.py -v
|
||||
|
||||
# 只测试知识库 API
|
||||
pytest tests/test_knowledge.py -v
|
||||
```
|
||||
|
||||
### 测试覆盖率
|
||||
|
||||
```bash
|
||||
pytest tests/ --cov=app --cov-report=html
|
||||
# 查看报告: open htmlcov/index.html
|
||||
```
|
||||
|
||||
### 测试目录结构
|
||||
|
||||
```
|
||||
tests/
|
||||
├── __init__.py
|
||||
├── conftest.py # pytest fixtures
|
||||
├── test_voices.py # 声音 API 测试
|
||||
├── test_assistants.py # 助手 API 测试
|
||||
├── test_history.py # 历史记录 API 测试
|
||||
└── test_knowledge.py # 知识库 API 测试
|
||||
```
|
||||
|
||||
### 测试用例统计
|
||||
|
||||
| 模块 | 测试用例数 |
|
||||
|------|-----------|
|
||||
| Voice | 13 |
|
||||
| Assistant | 14 |
|
||||
| History | 18 |
|
||||
| Knowledge | 19 |
|
||||
| **总计** | **64** |
|
||||
|
||||
### CI/CD 示例 (.github/workflows/test.yml)
|
||||
|
||||
```yaml
|
||||
name: Tests
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -r api/requirements.txt
|
||||
pip install pytest pytest-cov
|
||||
- name: Run tests
|
||||
run: pytest api/tests/ -v --cov=app
|
||||
```
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, DeclarativeBase
|
||||
import os
|
||||
|
||||
DATABASE_URL = "sqlite:///./data/app.db"
|
||||
# 使用绝对路径
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
DATABASE_URL = f"sqlite:///{os.path.join(BASE_DIR, 'data', 'app.db')}"
|
||||
|
||||
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
@@ -43,30 +43,3 @@ def root():
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# 初始化默认数据
|
||||
@app.on_event("startup")
|
||||
def init_default_data():
|
||||
from sqlalchemy.orm import Session
|
||||
from .db import SessionLocal
|
||||
from .models import Voice
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 检查是否已有数据
|
||||
if db.query(Voice).count() == 0:
|
||||
# 插入默认声音
|
||||
voices = [
|
||||
Voice(id="v1", name="Xiaoyun", vendor="Ali", gender="Female", language="zh", description="Gentle and professional."),
|
||||
Voice(id="v2", name="Kevin", vendor="Volcano", gender="Male", language="en", description="Deep and authoritative."),
|
||||
Voice(id="v3", name="Abby", vendor="Minimax", gender="Female", language="en", description="Cheerful and lively."),
|
||||
Voice(id="v4", name="Guang", vendor="Ali", gender="Male", language="zh", description="Standard newscast style."),
|
||||
Voice(id="v5", name="Doubao", vendor="Volcano", gender="Female", language="zh", description="Cute and young."),
|
||||
]
|
||||
for v in voices:
|
||||
db.add(v)
|
||||
db.commit()
|
||||
print("✅ 默认声音数据已初始化")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import String, Integer, DateTime, Text, Float, ForeignKey, JSON
|
||||
from sqlalchemy import String, Integer, DateTime, Text, Float, ForeignKey, JSON, Enum
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from .db import Base
|
||||
@@ -15,18 +15,72 @@ class User(Base):
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
|
||||
|
||||
# ============ Voice ============
|
||||
class Voice(Base):
|
||||
__tablename__ = "voices"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
user_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("users.id"), index=True, nullable=True)
|
||||
name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
vendor: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
gender: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
language: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
description: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
voice_params: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
model: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) # 厂商语音模型标识
|
||||
voice_key: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) # 厂商voice_key
|
||||
speed: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
gain: Mapped[int] = mapped_column(Integer, default=0)
|
||||
pitch: Mapped[int] = mapped_column(Integer, default=0)
|
||||
enabled: Mapped[bool] = mapped_column(default=True)
|
||||
is_system: Mapped[bool] = mapped_column(default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
|
||||
user = relationship("User", foreign_keys=[user_id])
|
||||
|
||||
|
||||
# ============ LLM Model ============
|
||||
class LLMModel(Base):
|
||||
__tablename__ = "llm_models"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True)
|
||||
name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
vendor: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(32), nullable=False) # text/embedding/rerank
|
||||
base_url: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
api_key: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
model_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
temperature: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
context_length: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
|
||||
user = relationship("User")
|
||||
|
||||
|
||||
# ============ ASR Model ============
|
||||
class ASRModel(Base):
|
||||
__tablename__ = "asr_models"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True)
|
||||
name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
vendor: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
language: Mapped[str] = mapped_column(String(32), nullable=False) # zh/en/Multi-lingual
|
||||
base_url: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
api_key: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
model_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
hotwords: Mapped[dict] = mapped_column(JSON, default=list)
|
||||
enable_punctuation: Mapped[bool] = mapped_column(default=True)
|
||||
enable_normalization: Mapped[bool] = mapped_column(default=True)
|
||||
enabled: Mapped[bool] = mapped_column(default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
|
||||
user = relationship("User")
|
||||
|
||||
|
||||
# ============ Assistant ============
|
||||
class Assistant(Base):
|
||||
__tablename__ = "assistants"
|
||||
|
||||
@@ -46,6 +100,11 @@ class Assistant(Base):
|
||||
config_mode: Mapped[str] = mapped_column(String(32), default="platform")
|
||||
api_url: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
api_key: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
# 模型关联
|
||||
llm_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
asr_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
embedding_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
rerank_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
|
||||
@@ -53,6 +112,7 @@ class Assistant(Base):
|
||||
call_records = relationship("CallRecord", back_populates="assistant")
|
||||
|
||||
|
||||
# ============ Knowledge Base ============
|
||||
class KnowledgeBase(Base):
|
||||
__tablename__ = "knowledge_bases"
|
||||
|
||||
@@ -92,6 +152,7 @@ class KnowledgeDocument(Base):
|
||||
kb = relationship("KnowledgeBase", back_populates="documents")
|
||||
|
||||
|
||||
# ============ Workflow ============
|
||||
class Workflow(Base):
|
||||
__tablename__ = "workflows"
|
||||
|
||||
@@ -108,6 +169,7 @@ class Workflow(Base):
|
||||
user = relationship("User")
|
||||
|
||||
|
||||
# ============ Call Record ============
|
||||
class CallRecord(Base):
|
||||
__tablename__ = "call_records"
|
||||
|
||||
|
||||
@@ -1,24 +1,203 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# ============ Enums ============
|
||||
class AssistantConfigMode(str, Enum):
|
||||
PLATFORM = "platform"
|
||||
DIFY = "dify"
|
||||
FASTGPT = "fastgpt"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class LLMModelType(str, Enum):
|
||||
TEXT = "text"
|
||||
EMBEDDING = "embedding"
|
||||
RERANK = "rerank"
|
||||
|
||||
|
||||
class ASRLanguage(str, Enum):
|
||||
ZH = "zh"
|
||||
EN = "en"
|
||||
MULTILINGUAL = "Multi-lingual"
|
||||
|
||||
|
||||
class VoiceGender(str, Enum):
|
||||
MALE = "Male"
|
||||
FEMALE = "Female"
|
||||
|
||||
|
||||
class CallRecordSource(str, Enum):
|
||||
DEBUG = "debug"
|
||||
EXTERNAL = "external"
|
||||
|
||||
|
||||
class CallRecordStatus(str, Enum):
|
||||
CONNECTED = "connected"
|
||||
MISSED = "missed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
# ============ Voice ============
|
||||
class VoiceBase(BaseModel):
|
||||
name: str
|
||||
vendor: str
|
||||
gender: str
|
||||
language: str
|
||||
description: str
|
||||
gender: str # "Male" | "Female"
|
||||
language: str # "zh" | "en"
|
||||
description: str = ""
|
||||
|
||||
|
||||
class VoiceCreate(VoiceBase):
|
||||
model: str # 厂商语音模型标识
|
||||
voice_key: str # 厂商voice_key
|
||||
speed: float = 1.0
|
||||
gain: int = 0
|
||||
pitch: int = 0
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class VoiceUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
voice_key: Optional[str] = None
|
||||
speed: Optional[float] = None
|
||||
gain: Optional[int] = None
|
||||
pitch: Optional[int] = None
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
|
||||
class VoiceOut(VoiceBase):
|
||||
id: str
|
||||
user_id: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
voice_key: Optional[str] = None
|
||||
speed: float = 1.0
|
||||
gain: int = 0
|
||||
pitch: int = 0
|
||||
enabled: bool = True
|
||||
is_system: bool = False
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class VoicePreviewRequest(BaseModel):
|
||||
text: str
|
||||
speed: Optional[float] = None
|
||||
gain: Optional[int] = None
|
||||
pitch: Optional[int] = None
|
||||
|
||||
|
||||
class VoicePreviewResponse(BaseModel):
|
||||
success: bool
|
||||
audio_url: Optional[str] = None
|
||||
duration_ms: Optional[int] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
# ============ LLM Model ============
|
||||
class LLMModelBase(BaseModel):
|
||||
name: str
|
||||
vendor: str
|
||||
type: LLMModelType
|
||||
base_url: str
|
||||
api_key: str
|
||||
model_name: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
context_length: Optional[int] = None
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class LLMModelCreate(LLMModelBase):
|
||||
pass
|
||||
|
||||
|
||||
class LLMModelUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
context_length: Optional[int] = None
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
|
||||
class LLMModelOut(LLMModelBase):
|
||||
id: str
|
||||
user_id: int
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class LLMModelTestResponse(BaseModel):
|
||||
success: bool
|
||||
latency_ms: Optional[int] = None
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
# ============ ASR Model ============
|
||||
class ASRModelBase(BaseModel):
|
||||
name: str
|
||||
vendor: str
|
||||
language: str # "zh" | "en" | "Multi-lingual"
|
||||
base_url: str
|
||||
api_key: str
|
||||
model_name: Optional[str] = None
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class ASRModelCreate(ASRModelBase):
|
||||
hotwords: List[str] = []
|
||||
enable_punctuation: bool = True
|
||||
enable_normalization: bool = True
|
||||
|
||||
|
||||
class ASRModelUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
hotwords: Optional[List[str]] = None
|
||||
enable_punctuation: Optional[bool] = None
|
||||
enable_normalization: Optional[bool] = None
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
|
||||
class ASRModelOut(ASRModelBase):
|
||||
id: str
|
||||
user_id: int
|
||||
hotwords: List[str] = []
|
||||
enable_punctuation: bool = True
|
||||
enable_normalization: bool = True
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ASRTestRequest(BaseModel):
|
||||
audio_url: Optional[str] = None
|
||||
audio_data: Optional[str] = None # base64 encoded
|
||||
|
||||
|
||||
class ASRTestResponse(BaseModel):
|
||||
success: bool
|
||||
transcript: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
confidence: Optional[float] = None
|
||||
duration_ms: Optional[int] = None
|
||||
latency_ms: Optional[int] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
# ============ Assistant ============
|
||||
class AssistantBase(BaseModel):
|
||||
name: str
|
||||
@@ -34,25 +213,56 @@ class AssistantBase(BaseModel):
|
||||
configMode: str = "platform"
|
||||
apiUrl: Optional[str] = None
|
||||
apiKey: Optional[str] = None
|
||||
# 模型关联
|
||||
llmModelId: Optional[str] = None
|
||||
asrModelId: Optional[str] = None
|
||||
embeddingModelId: Optional[str] = None
|
||||
rerankModelId: Optional[str] = None
|
||||
|
||||
|
||||
class AssistantCreate(AssistantBase):
|
||||
pass
|
||||
|
||||
|
||||
class AssistantUpdate(AssistantBase):
|
||||
class AssistantUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
opener: Optional[str] = None
|
||||
prompt: Optional[str] = None
|
||||
knowledgeBaseId: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
voice: Optional[str] = None
|
||||
speed: Optional[float] = None
|
||||
hotwords: Optional[List[str]] = None
|
||||
tools: Optional[List[str]] = None
|
||||
interruptionSensitivity: Optional[int] = None
|
||||
configMode: Optional[str] = None
|
||||
apiUrl: Optional[str] = None
|
||||
apiKey: Optional[str] = None
|
||||
llmModelId: Optional[str] = None
|
||||
asrModelId: Optional[str] = None
|
||||
embeddingModelId: Optional[str] = None
|
||||
rerankModelId: Optional[str] = None
|
||||
|
||||
|
||||
class AssistantOut(AssistantBase):
|
||||
id: str
|
||||
callCount: int = 0
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AssistantStats(BaseModel):
|
||||
assistant_id: str
|
||||
total_calls: int = 0
|
||||
connected_calls: int = 0
|
||||
missed_calls: int = 0
|
||||
avg_duration_seconds: float = 0.0
|
||||
today_calls: int = 0
|
||||
|
||||
|
||||
# ============ Knowledge Base ============
|
||||
class KnowledgeDocument(BaseModel):
|
||||
id: str
|
||||
@@ -196,6 +406,7 @@ class TranscriptSegment(BaseModel):
|
||||
endMs: int
|
||||
durationMs: Optional[int] = None
|
||||
audioUrl: Optional[str] = None
|
||||
emotion: Optional[str] = None
|
||||
|
||||
|
||||
class CallRecordCreate(BaseModel):
|
||||
@@ -208,6 +419,9 @@ class CallRecordUpdate(BaseModel):
|
||||
status: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
duration_seconds: Optional[int] = None
|
||||
ended_at: Optional[str] = None
|
||||
cost: Optional[float] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
class CallRecordOut(BaseModel):
|
||||
@@ -220,6 +434,9 @@ class CallRecordOut(BaseModel):
|
||||
ended_at: Optional[str] = None
|
||||
duration_seconds: Optional[int] = None
|
||||
summary: Optional[str] = None
|
||||
cost: float = 0.0
|
||||
metadata: dict = {}
|
||||
created_at: Optional[datetime] = None
|
||||
transcripts: List[TranscriptSegment] = []
|
||||
|
||||
class Config:
|
||||
@@ -246,6 +463,19 @@ class TranscriptOut(TranscriptCreate):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ============ History Stats ============
|
||||
class HistoryStats(BaseModel):
|
||||
total_calls: int = 0
|
||||
connected_calls: int = 0
|
||||
missed_calls: int = 0
|
||||
failed_calls: int = 0
|
||||
avg_duration_seconds: float = 0.0
|
||||
total_cost: float = 0.0
|
||||
by_status: dict = {}
|
||||
by_source: dict = {}
|
||||
daily_trend: List[dict] = []
|
||||
|
||||
|
||||
# ============ Dashboard ============
|
||||
class DashboardStats(BaseModel):
|
||||
totalCalls: int
|
||||
@@ -269,3 +499,9 @@ class ListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
list: List
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
id: str
|
||||
started_at: str
|
||||
matched_content: Optional[str] = None
|
||||
|
||||
@@ -6,47 +6,26 @@ import sys
|
||||
# 添加路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.db import Base, engine
|
||||
from app.db import Base, engine, DATABASE_URL
|
||||
from app.models import Voice
|
||||
|
||||
|
||||
def init_db():
|
||||
"""创建所有表"""
|
||||
# 确保 data 目录存在
|
||||
data_dir = os.path.dirname(DATABASE_URL.replace("sqlite:///", ""))
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
print("📦 创建数据库表...")
|
||||
Base.metadata.drop_all(bind=engine) # 删除旧表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
print("✅ 数据库表创建完成")
|
||||
|
||||
|
||||
def init_default_voices():
|
||||
"""初始化默认声音"""
|
||||
from app.db import SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if db.query(Voice).count() == 0:
|
||||
voices = [
|
||||
Voice(id="v1", name="Xiaoyun", vendor="Ali", gender="Female", language="zh", description="Gentle and professional."),
|
||||
Voice(id="v2", name="Kevin", vendor="Volcano", gender="Male", language="en", description="Deep and authoritative."),
|
||||
Voice(id="v3", name="Abby", vendor="Minimax", gender="Female", language="en", description="Cheerful and lively."),
|
||||
Voice(id="v4", name="Guang", vendor="Ali", gender="Male", language="zh", description="Standard newscast style."),
|
||||
Voice(id="v5", name="Doubao", vendor="Volcano", gender="Female", language="zh", description="Cute and young."),
|
||||
]
|
||||
for v in voices:
|
||||
db.add(v)
|
||||
db.commit()
|
||||
print("✅ 默认声音数据已初始化")
|
||||
else:
|
||||
print("ℹ️ 声音数据已存在,跳过初始化")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 确保 data 目录存在
|
||||
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
init_db()
|
||||
init_default_voices()
|
||||
print("🎉 数据库初始化完成!")
|
||||
|
||||
54
api/main.py
54
api/main.py
@@ -6,6 +6,9 @@ import os
|
||||
from app.db import Base, engine
|
||||
from app.routers import assistants, history, knowledge
|
||||
|
||||
# 配置
|
||||
PORT = int(os.getenv("PORT", 8100))
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -57,17 +60,54 @@ def init_default_data():
|
||||
try:
|
||||
# 检查是否已有数据
|
||||
if db.query(Voice).count() == 0:
|
||||
# 插入默认声音
|
||||
# SiliconFlow CosyVoice 2.0 预设声音 (8个)
|
||||
# 参考: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
|
||||
voices = [
|
||||
Voice(id="v1", name="Xiaoyun", vendor="Ali", gender="Female", language="zh", description="Gentle and professional."),
|
||||
Voice(id="v2", name="Kevin", vendor="Volcano", gender="Male", language="en", description="Deep and authoritative."),
|
||||
Voice(id="v3", name="Abby", vendor="Minimax", gender="Female", language="en", description="Cheerful and lively."),
|
||||
Voice(id="v4", name="Guang", vendor="Ali", gender="Male", language="zh", description="Standard newscast style."),
|
||||
Voice(id="v5", name="Doubao", vendor="Volcano", gender="Female", language="zh", description="Cute and young."),
|
||||
# 男声 (Male Voices)
|
||||
Voice(id="alex", name="Alex", vendor="SiliconFlow", gender="Male", language="en",
|
||||
description="Steady male voice.", is_system=True),
|
||||
Voice(id="benjamin", name="Benjamin", vendor="SiliconFlow", gender="Male", language="en",
|
||||
description="Deep male voice.", is_system=True),
|
||||
Voice(id="charles", name="Charles", vendor="SiliconFlow", gender="Male", language="en",
|
||||
description="Magnetic male voice.", is_system=True),
|
||||
Voice(id="david", name="David", vendor="SiliconFlow", gender="Male", language="en",
|
||||
description="Cheerful male voice.", is_system=True),
|
||||
# 女声 (Female Voices)
|
||||
Voice(id="anna", name="Anna", vendor="SiliconFlow", gender="Female", language="en",
|
||||
description="Steady female voice.", is_system=True),
|
||||
Voice(id="bella", name="Bella", vendor="SiliconFlow", gender="Female", language="en",
|
||||
description="Passionate female voice.", is_system=True),
|
||||
Voice(id="claire", name="Claire", vendor="SiliconFlow", gender="Female", language="en",
|
||||
description="Gentle female voice.", is_system=True),
|
||||
Voice(id="diana", name="Diana", vendor="SiliconFlow", gender="Female", language="en",
|
||||
description="Cheerful female voice.", is_system=True),
|
||||
# 中文方言 (Chinese Dialects) - 可选扩展
|
||||
Voice(id="amador", name="Amador", vendor="SiliconFlow", gender="Male", language="zh",
|
||||
description="Male voice with Spanish accent."),
|
||||
Voice(id="aelora", name="Aelora", vendor="SiliconFlow", gender="Female", language="en",
|
||||
description="Elegant female voice."),
|
||||
Voice(id="aelwin", name="Aelwin", vendor="SiliconFlow", gender="Male", language="en",
|
||||
description="Deep male voice."),
|
||||
Voice(id="blooming", name="Blooming", vendor="SiliconFlow", gender="Female", language="en",
|
||||
description="Fresh and clear female voice."),
|
||||
Voice(id="elysia", name="Elysia", vendor="SiliconFlow", gender="Female", language="en",
|
||||
description="Smooth and silky female voice."),
|
||||
Voice(id="leo", name="Leo", vendor="SiliconFlow", gender="Male", language="en",
|
||||
description="Young male voice."),
|
||||
Voice(id="lin", name="Lin", vendor="SiliconFlow", gender="Female", language="zh",
|
||||
description="Standard Chinese female voice."),
|
||||
Voice(id="rose", name="Rose", vendor="SiliconFlow", gender="Female", language="en",
|
||||
description="Soft and gentle female voice."),
|
||||
Voice(id="shao", name="Shao", vendor="SiliconFlow", gender="Male", language="zh",
|
||||
description="Deep Chinese male voice."),
|
||||
Voice(id="sky", name="Sky", vendor="SiliconFlow", gender="Male", language="en",
|
||||
description="Clear and bright male voice."),
|
||||
Voice(id="ael西山", name="Ael西山", vendor="SiliconFlow", gender="Female", language="zh",
|
||||
description="Female voice with Chinese dialect."),
|
||||
]
|
||||
for v in voices:
|
||||
db.add(v)
|
||||
db.commit()
|
||||
print("✅ 默认声音数据已初始化")
|
||||
print("✅ 默认声音数据已初始化 (SiliconFlow CosyVoice 2.0)")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
8
api/pytest.ini
Normal file
8
api/pytest.ini
Normal file
@@ -0,0 +1,8 @@
|
||||
[pytest]
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts = -v --tb=short
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
14
api/run_tests.bat
Normal file
14
api/run_tests.bat
Normal file
@@ -0,0 +1,14 @@
|
||||
@echo off
|
||||
REM Run API tests
|
||||
|
||||
cd /d "%~dp0"
|
||||
|
||||
REM Install test dependencies
|
||||
echo Installing test dependencies...
|
||||
pip install pytest pytest-cov -q
|
||||
|
||||
REM Run tests
|
||||
echo Running tests...
|
||||
pytest tests/ -v --tb=short
|
||||
|
||||
pause
|
||||
1
api/tests/__init__.py
Normal file
1
api/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests package
|
||||
102
api/tests/conftest.py
Normal file
102
api/tests/conftest.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Pytest fixtures for API tests"""
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
# Add api directory to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.db import Base, get_db
|
||||
from app.main import app
|
||||
|
||||
|
||||
# Use in-memory SQLite for testing
|
||||
DATABASE_URL = "sqlite:///:memory:"
|
||||
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
"""Create a fresh database session for each test"""
|
||||
# Create all tables
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
session = TestingSessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
# Drop all tables after test
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(db_session):
|
||||
"""Create a test client with database dependency override"""
|
||||
|
||||
def override_get_db():
|
||||
try:
|
||||
yield db_session
|
||||
finally:
|
||||
pass
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_voice_data():
|
||||
"""Sample voice data for testing"""
|
||||
return {
|
||||
"name": "Test Voice",
|
||||
"vendor": "TestVendor",
|
||||
"gender": "Female",
|
||||
"language": "zh",
|
||||
"description": "A test voice for unit testing",
|
||||
"model": "test-model",
|
||||
"voice_key": "test-key",
|
||||
"speed": 1.0,
|
||||
"gain": 0,
|
||||
"pitch": 0,
|
||||
"enabled": True
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_assistant_data():
|
||||
"""Sample assistant data for testing"""
|
||||
return {
|
||||
"name": "Test Assistant",
|
||||
"opener": "Hello, welcome!",
|
||||
"prompt": "You are a helpful assistant.",
|
||||
"language": "zh",
|
||||
"speed": 1.0,
|
||||
"hotwords": ["test", "hello"],
|
||||
"tools": [],
|
||||
"configMode": "platform"
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_call_record_data():
|
||||
"""Sample call record data for testing"""
|
||||
return {
|
||||
"user_id": 1,
|
||||
"assistant_id": None,
|
||||
"source": "debug"
|
||||
}
|
||||
168
api/tests/test_assistants.py
Normal file
168
api/tests/test_assistants.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Tests for Assistant API endpoints"""
|
||||
import pytest
|
||||
import uuid
|
||||
|
||||
|
||||
class TestAssistantAPI:
|
||||
"""Test cases for Assistant endpoints"""
|
||||
|
||||
def test_get_assistants_empty(self, client):
|
||||
"""Test getting assistants when database is empty"""
|
||||
response = client.get("/api/assistants")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total" in data
|
||||
assert "list" in data
|
||||
|
||||
def test_create_assistant(self, client, sample_assistant_data):
|
||||
"""Test creating a new assistant"""
|
||||
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == sample_assistant_data["name"]
|
||||
assert data["opener"] == sample_assistant_data["opener"]
|
||||
assert data["prompt"] == sample_assistant_data["prompt"]
|
||||
assert data["language"] == sample_assistant_data["language"]
|
||||
assert "id" in data
|
||||
assert data["callCount"] == 0
|
||||
|
||||
def test_create_assistant_minimal(self, client):
|
||||
"""Test creating an assistant with minimal required data"""
|
||||
data = {"name": "Minimal Assistant"}
|
||||
response = client.post("/api/assistants", json=data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Minimal Assistant"
|
||||
|
||||
def test_get_assistant_by_id(self, client, sample_assistant_data):
|
||||
"""Test getting a specific assistant by ID"""
|
||||
# Create first
|
||||
create_response = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assistant_id = create_response.json()["id"]
|
||||
|
||||
# Get by ID
|
||||
response = client.get(f"/api/assistants/{assistant_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == assistant_id
|
||||
assert data["name"] == sample_assistant_data["name"]
|
||||
|
||||
def test_get_assistant_not_found(self, client):
|
||||
"""Test getting a non-existent assistant"""
|
||||
response = client.get("/api/assistants/non-existent-id")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_update_assistant(self, client, sample_assistant_data):
|
||||
"""Test updating an assistant"""
|
||||
# Create first
|
||||
create_response = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assistant_id = create_response.json()["id"]
|
||||
|
||||
# Update
|
||||
update_data = {
|
||||
"name": "Updated Assistant",
|
||||
"prompt": "You are an updated assistant.",
|
||||
"speed": 1.5
|
||||
}
|
||||
response = client.put(f"/api/assistants/{assistant_id}", json=update_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Updated Assistant"
|
||||
assert data["prompt"] == "You are an updated assistant."
|
||||
assert data["speed"] == 1.5
|
||||
|
||||
def test_delete_assistant(self, client, sample_assistant_data):
|
||||
"""Test deleting an assistant"""
|
||||
# Create first
|
||||
create_response = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assistant_id = create_response.json()["id"]
|
||||
|
||||
# Delete
|
||||
response = client.delete(f"/api/assistants/{assistant_id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify deleted
|
||||
get_response = client.get(f"/api/assistants/{assistant_id}")
|
||||
assert get_response.status_code == 404
|
||||
|
||||
def test_list_assistants_with_pagination(self, client, sample_assistant_data):
|
||||
"""Test listing assistants with pagination"""
|
||||
# Create multiple assistants
|
||||
for i in range(3):
|
||||
data = sample_assistant_data.copy()
|
||||
data["name"] = f"Assistant {i}"
|
||||
client.post("/api/assistants", json=data)
|
||||
|
||||
# Test pagination
|
||||
response = client.get("/api/assistants?page=1&limit=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 3
|
||||
assert len(data["list"]) == 2
|
||||
|
||||
def test_create_assistant_with_voice(self, client, sample_assistant_data, sample_voice_data):
|
||||
"""Test creating an assistant with a voice reference"""
|
||||
# Create a voice first
|
||||
voice_response = client.post("/api/voices", json=sample_voice_data)
|
||||
voice_id = voice_response.json()["id"]
|
||||
|
||||
# Create assistant with voice
|
||||
sample_assistant_data["voice"] = voice_id
|
||||
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["voice"] == voice_id
|
||||
|
||||
def test_create_assistant_with_knowledge_base(self, client, sample_assistant_data):
|
||||
"""Test creating an assistant with knowledge base reference"""
|
||||
# Note: This test assumes knowledge base doesn't exist
|
||||
sample_assistant_data["knowledgeBaseId"] = "non-existent-kb"
|
||||
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["knowledgeBaseId"] == "non-existent-kb"
|
||||
|
||||
def test_assistant_with_model_references(self, client, sample_assistant_data):
|
||||
"""Test creating assistant with model references"""
|
||||
sample_assistant_data.update({
|
||||
"llmModelId": "llm-001",
|
||||
"asrModelId": "asr-001",
|
||||
"embeddingModelId": "emb-001",
|
||||
"rerankModelId": "rerank-001"
|
||||
})
|
||||
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["llmModelId"] == "llm-001"
|
||||
assert data["asrModelId"] == "asr-001"
|
||||
assert data["embeddingModelId"] == "emb-001"
|
||||
assert data["rerankModelId"] == "rerank-001"
|
||||
|
||||
def test_assistant_with_tools(self, client, sample_assistant_data):
|
||||
"""Test creating assistant with tools"""
|
||||
sample_assistant_data["tools"] = ["weather", "calculator", "search"]
|
||||
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["tools"] == ["weather", "calculator", "search"]
|
||||
|
||||
def test_assistant_with_hotwords(self, client, sample_assistant_data):
|
||||
"""Test creating assistant with hotwords"""
|
||||
sample_assistant_data["hotwords"] = ["hello", "help", "stop"]
|
||||
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["hotwords"] == ["hello", "help", "stop"]
|
||||
|
||||
def test_different_config_modes(self, client, sample_assistant_data):
|
||||
"""Test creating assistants with different config modes"""
|
||||
for mode in ["platform", "dify", "fastgpt", "none"]:
|
||||
sample_assistant_data["name"] = f"Assistant {mode}"
|
||||
sample_assistant_data["configMode"] = mode
|
||||
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["configMode"] == mode
|
||||
|
||||
def test_different_languages(self, client, sample_assistant_data):
|
||||
"""Test creating assistants with different languages"""
|
||||
for lang in ["zh", "en", "ja", "ko"]:
|
||||
sample_assistant_data["name"] = f"Assistant {lang}"
|
||||
sample_assistant_data["language"] = lang
|
||||
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["language"] == lang
|
||||
236
api/tests/test_history.py
Normal file
236
api/tests/test_history.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Tests for History/Call Record API endpoints"""
|
||||
import pytest
|
||||
import time
|
||||
|
||||
|
||||
class TestHistoryAPI:
|
||||
"""Test cases for History/Call Record endpoints"""
|
||||
|
||||
def test_get_history_empty(self, client):
|
||||
"""Test getting history when database is empty"""
|
||||
response = client.get("/api/history")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total" in data
|
||||
assert "list" in data
|
||||
|
||||
def test_create_call_record(self, client, sample_call_record_data):
|
||||
"""Test creating a new call record"""
|
||||
response = client.post("/api/history", json=sample_call_record_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == sample_call_record_data["user_id"]
|
||||
assert data["source"] == sample_call_record_data["source"]
|
||||
assert data["status"] == "connected"
|
||||
assert "id" in data
|
||||
assert "started_at" in data
|
||||
|
||||
def test_create_call_record_with_assistant(self, client, sample_assistant_data, sample_call_record_data):
|
||||
"""Test creating a call record associated with an assistant"""
|
||||
# Create assistant first
|
||||
assistant_response = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assistant_id = assistant_response.json()["id"]
|
||||
|
||||
# Create call record with assistant
|
||||
sample_call_record_data["assistant_id"] = assistant_id
|
||||
response = client.post("/api/history", json=sample_call_record_data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["assistant_id"] == assistant_id
|
||||
|
||||
def test_get_call_record_by_id(self, client, sample_call_record_data):
|
||||
"""Test getting a specific call record by ID"""
|
||||
# Create first
|
||||
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||
record_id = create_response.json()["id"]
|
||||
|
||||
# Get by ID
|
||||
response = client.get(f"/api/history/{record_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == record_id
|
||||
|
||||
def test_get_call_record_not_found(self, client):
|
||||
"""Test getting a non-existent call record"""
|
||||
response = client.get("/api/history/non-existent-id")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_update_call_record(self, client, sample_call_record_data):
|
||||
"""Test updating a call record"""
|
||||
# Create first
|
||||
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||
record_id = create_response.json()["id"]
|
||||
|
||||
# Update
|
||||
update_data = {
|
||||
"status": "completed",
|
||||
"summary": "Test summary",
|
||||
"duration_seconds": 120
|
||||
}
|
||||
response = client.put(f"/api/history/{record_id}", json=update_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "completed"
|
||||
assert data["summary"] == "Test summary"
|
||||
assert data["duration_seconds"] == 120
|
||||
|
||||
def test_delete_call_record(self, client, sample_call_record_data):
|
||||
"""Test deleting a call record"""
|
||||
# Create first
|
||||
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||
record_id = create_response.json()["id"]
|
||||
|
||||
# Delete
|
||||
response = client.delete(f"/api/history/{record_id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify deleted
|
||||
get_response = client.get(f"/api/history/{record_id}")
|
||||
assert get_response.status_code == 404
|
||||
|
||||
def test_add_transcript(self, client, sample_call_record_data):
|
||||
"""Test adding a transcript to a call record"""
|
||||
# Create call record first
|
||||
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||
record_id = create_response.json()["id"]
|
||||
|
||||
# Add transcript
|
||||
transcript_data = {
|
||||
"turn_index": 0,
|
||||
"speaker": "human",
|
||||
"content": "Hello, I need help",
|
||||
"start_ms": 0,
|
||||
"end_ms": 3000,
|
||||
"confidence": 0.95
|
||||
}
|
||||
response = client.post(
|
||||
f"/api/history/{record_id}/transcripts",
|
||||
json=transcript_data
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["turn_index"] == 0
|
||||
assert data["speaker"] == "human"
|
||||
assert data["content"] == "Hello, I need help"
|
||||
|
||||
def test_add_multiple_transcripts(self, client, sample_call_record_data):
|
||||
"""Test adding multiple transcripts"""
|
||||
# Create call record first
|
||||
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||
record_id = create_response.json()["id"]
|
||||
|
||||
# Add human transcript
|
||||
human_transcript = {
|
||||
"turn_index": 0,
|
||||
"speaker": "human",
|
||||
"content": "Hello",
|
||||
"start_ms": 0,
|
||||
"end_ms": 1000
|
||||
}
|
||||
client.post(f"/api/history/{record_id}/transcripts", json=human_transcript)
|
||||
|
||||
# Add AI transcript
|
||||
ai_transcript = {
|
||||
"turn_index": 1,
|
||||
"speaker": "ai",
|
||||
"content": "Hello! How can I help you?",
|
||||
"start_ms": 1500,
|
||||
"end_ms": 4000
|
||||
}
|
||||
client.post(f"/api/history/{record_id}/transcripts", json=ai_transcript)
|
||||
|
||||
# Verify both transcripts exist
|
||||
response = client.get(f"/api/history/{record_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["transcripts"]) == 2
|
||||
|
||||
def test_filter_history_by_status(self, client, sample_call_record_data):
|
||||
"""Test filtering history by status"""
|
||||
# Create records with different statuses
|
||||
for i in range(2):
|
||||
data = sample_call_record_data.copy()
|
||||
data["status"] = "connected" if i % 2 == 0 else "missed"
|
||||
client.post("/api/history", json=data)
|
||||
|
||||
# Filter by status
|
||||
response = client.get("/api/history?status=connected")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for record in data["list"]:
|
||||
assert record["status"] == "connected"
|
||||
|
||||
def test_filter_history_by_source(self, client, sample_call_record_data):
|
||||
"""Test filtering history by source"""
|
||||
sample_call_record_data["source"] = "external"
|
||||
client.post("/api/history", json=sample_call_record_data)
|
||||
|
||||
response = client.get("/api/history?source=external")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for record in data["list"]:
|
||||
assert record["source"] == "external"
|
||||
|
||||
def test_history_pagination(self, client, sample_call_record_data):
|
||||
"""Test history pagination"""
|
||||
# Create multiple records
|
||||
for i in range(5):
|
||||
data = sample_call_record_data.copy()
|
||||
data["source"] = f"source-{i}"
|
||||
client.post("/api/history", json=data)
|
||||
|
||||
# Test pagination
|
||||
response = client.get("/api/history?page=1&limit=3")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 5
|
||||
assert len(data["list"]) == 3
|
||||
|
||||
def test_transcript_with_emotion(self, client, sample_call_record_data):
|
||||
"""Test adding transcript with emotion"""
|
||||
# Create call record first
|
||||
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||
record_id = create_response.json()["id"]
|
||||
|
||||
# Add transcript with emotion
|
||||
transcript_data = {
|
||||
"turn_index": 0,
|
||||
"speaker": "ai",
|
||||
"content": "Great news!",
|
||||
"start_ms": 0,
|
||||
"end_ms": 2000,
|
||||
"emotion": "happy"
|
||||
}
|
||||
response = client.post(
|
||||
f"/api/history/{record_id}/transcripts",
|
||||
json=transcript_data
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["emotion"] == "happy"
|
||||
|
||||
def test_history_with_cost(self, client, sample_call_record_data):
|
||||
"""Test creating history with cost"""
|
||||
sample_call_record_data["cost"] = 0.05
|
||||
response = client.post("/api/history", json=sample_call_record_data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["cost"] == 0.05
|
||||
|
||||
def test_history_search(self, client, sample_call_record_data):
|
||||
"""Test searching history"""
|
||||
# Create record
|
||||
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||
record_id = create_response.json()["id"]
|
||||
|
||||
# Add transcript with searchable content
|
||||
transcript_data = {
|
||||
"turn_index": 0,
|
||||
"speaker": "human",
|
||||
"content": "I want to buy a product",
|
||||
"start_ms": 0,
|
||||
"end_ms": 3000
|
||||
}
|
||||
client.post(f"/api/history/{record_id}/transcripts", json=transcript_data)
|
||||
|
||||
# Search (this endpoint may not exist yet)
|
||||
response = client.get("/api/history/search?q=product")
|
||||
# This might return 404 if endpoint doesn't exist
|
||||
assert response.status_code in [200, 404]
|
||||
255
api/tests/test_knowledge.py
Normal file
255
api/tests/test_knowledge.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Tests for Knowledge Base API endpoints"""
|
||||
import pytest
|
||||
import uuid
|
||||
|
||||
|
||||
class TestKnowledgeAPI:
|
||||
"""Test cases for Knowledge Base endpoints"""
|
||||
|
||||
def test_get_knowledge_bases_empty(self, client):
|
||||
"""Test getting knowledge bases when database is empty"""
|
||||
response = client.get("/api/knowledge/bases")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total" in data
|
||||
assert "list" in data
|
||||
|
||||
def test_create_knowledge_base(self, client):
|
||||
"""Test creating a new knowledge base"""
|
||||
data = {
|
||||
"name": "Test Knowledge Base",
|
||||
"description": "A test knowledge base",
|
||||
"embeddingModel": "text-embedding-3-small",
|
||||
"chunkSize": 500,
|
||||
"chunkOverlap": 50
|
||||
}
|
||||
response = client.post("/api/knowledge/bases", json=data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Knowledge Base"
|
||||
assert data["description"] == "A test knowledge base"
|
||||
assert data["embeddingModel"] == "text-embedding-3-small"
|
||||
assert "id" in data
|
||||
assert data["docCount"] == 0
|
||||
assert data["chunkCount"] == 0
|
||||
assert data["status"] == "active"
|
||||
|
||||
def test_create_knowledge_base_minimal(self, client):
|
||||
"""Test creating a knowledge base with minimal data"""
|
||||
data = {"name": "Minimal KB"}
|
||||
response = client.post("/api/knowledge/bases", json=data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Minimal KB"
|
||||
|
||||
def test_get_knowledge_base_by_id(self, client):
|
||||
"""Test getting a specific knowledge base by ID"""
|
||||
# Create first
|
||||
create_data = {"name": "Test KB"}
|
||||
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||
kb_id = create_response.json()["id"]
|
||||
|
||||
# Get by ID
|
||||
response = client.get(f"/api/knowledge/bases/{kb_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == kb_id
|
||||
assert data["name"] == "Test KB"
|
||||
|
||||
def test_get_knowledge_base_not_found(self, client):
|
||||
"""Test getting a non-existent knowledge base"""
|
||||
response = client.get("/api/knowledge/bases/non-existent-id")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_update_knowledge_base(self, client):
|
||||
"""Test updating a knowledge base"""
|
||||
# Create first
|
||||
create_data = {"name": "Original Name"}
|
||||
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||
kb_id = create_response.json()["id"]
|
||||
|
||||
# Update
|
||||
update_data = {
|
||||
"name": "Updated Name",
|
||||
"description": "Updated description",
|
||||
"chunkSize": 800
|
||||
}
|
||||
response = client.put(f"/api/knowledge/bases/{kb_id}", json=update_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Updated Name"
|
||||
assert data["description"] == "Updated description"
|
||||
assert data["chunkSize"] == 800
|
||||
|
||||
def test_delete_knowledge_base(self, client):
|
||||
"""Test deleting a knowledge base"""
|
||||
# Create first
|
||||
create_data = {"name": "To Delete"}
|
||||
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||
kb_id = create_response.json()["id"]
|
||||
|
||||
# Delete
|
||||
response = client.delete(f"/api/knowledge/bases/{kb_id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify deleted
|
||||
get_response = client.get(f"/api/knowledge/bases/{kb_id}")
|
||||
assert get_response.status_code == 404
|
||||
|
||||
def test_upload_document(self, client):
|
||||
"""Test uploading a document to knowledge base"""
|
||||
# Create KB first
|
||||
create_data = {"name": "Test KB for Docs"}
|
||||
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||
kb_id = create_response.json()["id"]
|
||||
|
||||
# Upload document
|
||||
doc_data = {
|
||||
"name": "test-document.txt",
|
||||
"size": "1024",
|
||||
"fileType": "txt",
|
||||
"storageUrl": "https://storage.example.com/test-document.txt"
|
||||
}
|
||||
response = client.post(
|
||||
f"/api/knowledge/bases/{kb_id}/documents",
|
||||
json=doc_data
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "test-document.txt"
|
||||
assert "id" in data
|
||||
assert data["status"] == "pending"
|
||||
|
||||
def test_delete_document(self, client):
|
||||
"""Test deleting a document from knowledge base"""
|
||||
# Create KB first
|
||||
create_data = {"name": "Test KB for Delete"}
|
||||
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||
kb_id = create_response.json()["id"]
|
||||
|
||||
# Upload document
|
||||
doc_data = {"name": "to-delete.txt", "size": "100", "fileType": "txt"}
|
||||
upload_response = client.post(
|
||||
f"/api/knowledge/bases/{kb_id}/documents",
|
||||
json=doc_data
|
||||
)
|
||||
doc_id = upload_response.json()["id"]
|
||||
|
||||
# Delete document
|
||||
response = client.delete(
|
||||
f"/api/knowledge/bases/{kb_id}/documents/{doc_id}"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_index_document(self, client):
|
||||
"""Test indexing a document"""
|
||||
# Create KB first
|
||||
create_data = {"name": "Test KB for Index"}
|
||||
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||
kb_id = create_response.json()["id"]
|
||||
|
||||
# Index document
|
||||
index_data = {
|
||||
"document_id": "doc-001",
|
||||
"content": "This is the content to index. It contains important information about the product."
|
||||
}
|
||||
response = client.post(
|
||||
f"/api/knowledge/bases/{kb_id}/documents/doc-001/index",
|
||||
json=index_data
|
||||
)
|
||||
# This might return 200 or error depending on vector store implementation
|
||||
assert response.status_code in [200, 500]
|
||||
|
||||
def test_search_knowledge(self, client):
|
||||
"""Test searching knowledge base"""
|
||||
# Create KB first
|
||||
create_data = {"name": "Test KB for Search"}
|
||||
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||
kb_id = create_response.json()["id"]
|
||||
|
||||
# Search (this may fail without indexed content)
|
||||
search_data = {
|
||||
"query": "test query",
|
||||
"kb_id": kb_id,
|
||||
"nResults": 5
|
||||
}
|
||||
response = client.post("/api/knowledge/search", json=search_data)
|
||||
# This might return 200 or error depending on implementation
|
||||
assert response.status_code in [200, 500]
|
||||
|
||||
def test_get_knowledge_stats(self, client):
|
||||
"""Test getting knowledge base statistics"""
|
||||
# Create KB first
|
||||
create_data = {"name": "Test KB for Stats"}
|
||||
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||
kb_id = create_response.json()["id"]
|
||||
|
||||
response = client.get(f"/api/knowledge/bases/{kb_id}/stats")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["kb_id"] == kb_id
|
||||
assert "docCount" in data
|
||||
assert "chunkCount" in data
|
||||
|
||||
def test_knowledge_bases_pagination(self, client):
|
||||
"""Test knowledge bases pagination"""
|
||||
# Create multiple KBs
|
||||
for i in range(5):
|
||||
data = {"name": f"Knowledge Base {i}"}
|
||||
client.post("/api/knowledge/bases", json=data)
|
||||
|
||||
# Test pagination
|
||||
response = client.get("/api/knowledge/bases?page=1&limit=3")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 5
|
||||
assert len(data["list"]) == 3
|
||||
|
||||
def test_different_embedding_models(self, client):
|
||||
"""Test creating KB with different embedding models"""
|
||||
models = [
|
||||
"text-embedding-3-small",
|
||||
"text-embedding-3-large",
|
||||
"bge-small-zh"
|
||||
]
|
||||
for model in models:
|
||||
data = {"name": f"KB with {model}", "embeddingModel": model}
|
||||
response = client.post("/api/knowledge/bases", json=data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["embeddingModel"] == model
|
||||
|
||||
def test_different_chunk_sizes(self, client):
|
||||
"""Test creating KB with different chunk configurations"""
|
||||
configs = [
|
||||
{"chunkSize": 500, "chunkOverlap": 50},
|
||||
{"chunkSize": 1000, "chunkOverlap": 100},
|
||||
{"chunkSize": 256, "chunkOverlap": 25}
|
||||
]
|
||||
for config in configs:
|
||||
data = {"name": "Chunk Test KB", **config}
|
||||
response = client.post("/api/knowledge/bases", json=data)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_knowledge_base_with_documents(self, client):
|
||||
"""Test creating KB and adding multiple documents"""
|
||||
# Create KB
|
||||
create_data = {"name": "KB with Multiple Docs"}
|
||||
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||
kb_id = create_response.json()["id"]
|
||||
|
||||
# Add multiple documents
|
||||
for i in range(3):
|
||||
doc_data = {
|
||||
"name": f"document-{i}.txt",
|
||||
"size": f"{1000 + i * 100}",
|
||||
"fileType": "txt"
|
||||
}
|
||||
client.post(
|
||||
f"/api/knowledge/bases/{kb_id}/documents",
|
||||
json=doc_data
|
||||
)
|
||||
|
||||
# Verify documents are listed
|
||||
response = client.get(f"/api/knowledge/bases/{kb_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["documents"]) == 3
|
||||
132
api/tests/test_voices.py
Normal file
132
api/tests/test_voices.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Tests for Voice API endpoints"""
|
||||
import pytest
|
||||
|
||||
|
||||
class TestVoiceAPI:
|
||||
"""Test cases for Voice endpoints"""
|
||||
|
||||
def test_get_voices_empty(self, client):
|
||||
"""Test getting voices when database is empty"""
|
||||
response = client.get("/api/voices")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total" in data
|
||||
assert "list" in data
|
||||
|
||||
def test_create_voice(self, client, sample_voice_data):
|
||||
"""Test creating a new voice"""
|
||||
response = client.post("/api/voices", json=sample_voice_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == sample_voice_data["name"]
|
||||
assert data["vendor"] == sample_voice_data["vendor"]
|
||||
assert data["gender"] == sample_voice_data["gender"]
|
||||
assert data["language"] == sample_voice_data["language"]
|
||||
assert "id" in data
|
||||
|
||||
def test_create_voice_minimal(self, client):
|
||||
"""Test creating a voice with minimal data"""
|
||||
data = {
|
||||
"name": "Minimal Voice",
|
||||
"vendor": "Test",
|
||||
"gender": "Male",
|
||||
"language": "en",
|
||||
"description": ""
|
||||
}
|
||||
response = client.post("/api/voices", json=data)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_get_voice_by_id(self, client, sample_voice_data):
|
||||
"""Test getting a specific voice by ID"""
|
||||
# Create first
|
||||
create_response = client.post("/api/voices", json=sample_voice_data)
|
||||
voice_id = create_response.json()["id"]
|
||||
|
||||
# Get by ID
|
||||
response = client.get(f"/api/voices/{voice_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == voice_id
|
||||
assert data["name"] == sample_voice_data["name"]
|
||||
|
||||
def test_get_voice_not_found(self, client):
|
||||
"""Test getting a non-existent voice"""
|
||||
response = client.get("/api/voices/non-existent-id")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_update_voice(self, client, sample_voice_data):
|
||||
"""Test updating a voice"""
|
||||
# Create first
|
||||
create_response = client.post("/api/voices", json=sample_voice_data)
|
||||
voice_id = create_response.json()["id"]
|
||||
|
||||
# Update
|
||||
update_data = {"name": "Updated Voice", "speed": 1.5}
|
||||
response = client.put(f"/api/voices/{voice_id}", json=update_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Updated Voice"
|
||||
assert data["speed"] == 1.5
|
||||
|
||||
def test_delete_voice(self, client, sample_voice_data):
|
||||
"""Test deleting a voice"""
|
||||
# Create first
|
||||
create_response = client.post("/api/voices", json=sample_voice_data)
|
||||
voice_id = create_response.json()["id"]
|
||||
|
||||
# Delete
|
||||
response = client.delete(f"/api/voices/{voice_id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify deleted
|
||||
get_response = client.get(f"/api/voices/{voice_id}")
|
||||
assert get_response.status_code == 404
|
||||
|
||||
def test_list_voices_with_pagination(self, client, sample_voice_data):
|
||||
"""Test listing voices with pagination"""
|
||||
# Create multiple voices
|
||||
for i in range(3):
|
||||
data = sample_voice_data.copy()
|
||||
data["name"] = f"Voice {i}"
|
||||
client.post("/api/voices", json=data)
|
||||
|
||||
# Test pagination
|
||||
response = client.get("/api/voices?page=1&limit=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 3
|
||||
assert len(data["list"]) == 2
|
||||
|
||||
def test_filter_voices_by_vendor(self, client, sample_voice_data):
|
||||
"""Test filtering voices by vendor"""
|
||||
# Create voice with specific vendor
|
||||
sample_voice_data["vendor"] = "FilterTestVendor"
|
||||
client.post("/api/voices", json=sample_voice_data)
|
||||
|
||||
response = client.get("/api/voices?vendor=FilterTestVendor")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for voice in data["list"]:
|
||||
assert voice["vendor"] == "FilterTestVendor"
|
||||
|
||||
def test_filter_voices_by_language(self, client, sample_voice_data):
|
||||
"""Test filtering voices by language"""
|
||||
sample_voice_data["language"] = "en"
|
||||
client.post("/api/voices", json=sample_voice_data)
|
||||
|
||||
response = client.get("/api/voices?language=en")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for voice in data["list"]:
|
||||
assert voice["language"] == "en"
|
||||
|
||||
def test_filter_voices_by_gender(self, client, sample_voice_data):
|
||||
"""Test filtering voices by gender"""
|
||||
sample_voice_data["gender"] = "Female"
|
||||
client.post("/api/voices", json=sample_voice_data)
|
||||
|
||||
response = client.get("/api/voices?gender=Female")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for voice in data["list"]:
|
||||
assert voice["gender"] == "Female"
|
||||
Reference in New Issue
Block a user