Compare commits
3 Commits
86744f0842
...
3d8635670f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3d8635670f | ||
|
|
7012f8edaf | ||
|
|
727fe8a997 |
@@ -1,4 +1,4 @@
|
|||||||
FROM python:3.11-slim
|
FROM python:3.12-slim
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
@@ -12,6 +12,6 @@ COPY . .
|
|||||||
# 创建数据目录
|
# 创建数据目录
|
||||||
RUN mkdir -p /app/data
|
RUN mkdir -p /app/data
|
||||||
|
|
||||||
EXPOSE 8000
|
EXPOSE 8100
|
||||||
|
|
||||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8100", "--reload"]
|
||||||
|
|||||||
273
api/README.md
273
api/README.md
@@ -1,13 +1,13 @@
|
|||||||
# AI VideoAssistant Backend
|
# AI VideoAssistant Backend
|
||||||
|
|
||||||
Python 后端 API,配合前端 `ai-videoassistant-frontend` 使用。
|
Python 后端 API,配合前端 `web/` 模块使用。
|
||||||
|
|
||||||
## 快速开始
|
## 快速开始
|
||||||
|
|
||||||
### 1. 安装依赖
|
### 1. 安装依赖
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd ~/Code/ai-videoassistant-backend
|
cd api
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -25,44 +25,162 @@ python init_db.py
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 开发模式 (热重载)
|
# 开发模式 (热重载)
|
||||||
python -m uvicorn main:app --reload --host 0.0.0.0 --port 8000
|
python -m uvicorn main:app --reload --host 0.0.0.0 --port 8100
|
||||||
```
|
```
|
||||||
|
|
||||||
|
服务运行在: http://localhost:8100
|
||||||
|
|
||||||
### 4. 测试 API
|
### 4. 测试 API
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 健康检查
|
# 健康检查
|
||||||
curl http://localhost:8000/health
|
curl http://localhost:8100/health
|
||||||
|
|
||||||
# 获取助手列表
|
# 获取助手列表
|
||||||
curl http://localhost:8000/api/assistants
|
curl http://localhost:8100/api/assistants
|
||||||
|
|
||||||
# 获取声音列表
|
# 获取声音列表
|
||||||
curl http://localhost:8000/api/voices
|
curl http://localhost:8100/api/voices
|
||||||
|
|
||||||
# 获取通话历史
|
# 获取通话历史
|
||||||
curl http://localhost:8000/api/history
|
curl http://localhost:8100/api/history
|
||||||
```
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## API 文档
|
## API 文档
|
||||||
|
|
||||||
| 端点 | 方法 | 说明 |
|
完整 API 文档位于 [docs/](docs/) 目录:
|
||||||
|
|
||||||
|
| 模块 | 端点 | 方法 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| **Assistant** | `/api/assistants` | GET | 助手列表 |
|
||||||
|
| | | POST | 创建助手 |
|
||||||
|
| | `/api/assistants/{id}` | GET | 助手详情 |
|
||||||
|
| | | PUT | 更新助手 |
|
||||||
|
| | | DELETE | 删除助手 |
|
||||||
|
| **Voice** | `/api/voices` | GET | 声音库列表 |
|
||||||
|
| | | POST | 添加声音 |
|
||||||
|
| | `/api/voices/{id}` | GET | 声音详情 |
|
||||||
|
| | | PUT | 更新声音 |
|
||||||
|
| | | DELETE | 删除声音 |
|
||||||
|
| | `/api/voices/{id}/preview` | POST | 预览声音 |
|
||||||
|
| **LLM Models** | `/api/models/llm` | GET | LLM 模型列表 |
|
||||||
|
| | | POST | 添加模型 |
|
||||||
|
| | `/api/models/llm/{id}` | GET | 模型详情 |
|
||||||
|
| | | PUT | 更新模型 |
|
||||||
|
| | | DELETE | 删除模型 |
|
||||||
|
| | `/api/models/llm/{id}/test` | POST | 测试模型连接 |
|
||||||
|
| **ASR Models** | `/api/models/asr` | GET | ASR 模型列表 |
|
||||||
|
| | | POST | 添加模型 |
|
||||||
|
| | `/api/models/asr/{id}` | GET | 模型详情 |
|
||||||
|
| | | PUT | 更新模型 |
|
||||||
|
| | | DELETE | 删除模型 |
|
||||||
|
| | `/api/models/asr/{id}/test` | POST | 测试识别 |
|
||||||
|
| **History** | `/api/history` | GET | 通话历史列表 |
|
||||||
|
| | `/api/history/{id}` | GET | 通话详情 |
|
||||||
|
| | | PUT | 更新通话记录 |
|
||||||
|
| | | DELETE | 删除记录 |
|
||||||
|
| | `/api/history/{id}/transcripts` | POST | 添加转写 |
|
||||||
|
| | `/api/history/search` | GET | 搜索历史 |
|
||||||
|
| | `/api/history/stats` | GET | 统计数据 |
|
||||||
|
| **Knowledge** | `/api/knowledge/bases` | GET | 知识库列表 |
|
||||||
|
| | | POST | 创建知识库 |
|
||||||
|
| | `/api/knowledge/bases/{id}` | GET | 知识库详情 |
|
||||||
|
| | | PUT | 更新知识库 |
|
||||||
|
| | | DELETE | 删除知识库 |
|
||||||
|
| | `/api/knowledge/bases/{kb_id}/documents` | POST | 上传文档 |
|
||||||
|
| | `/api/knowledge/bases/{kb_id}/documents/{doc_id}` | DELETE | 删除文档 |
|
||||||
|
| | `/api/knowledge/bases/{kb_id}/documents/{doc_id}/index` | POST | 索引文档 |
|
||||||
|
| | `/api/knowledge/search` | POST | 知识搜索 |
|
||||||
|
| **Workflow** | `/api/workflows` | GET | 工作流列表 |
|
||||||
|
| | | POST | 创建工作流 |
|
||||||
|
| | `/api/workflows/{id}` | GET | 工作流详情 |
|
||||||
|
| | | PUT | 更新工作流 |
|
||||||
|
| | | DELETE | 删除工作流 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 数据模型
|
||||||
|
|
||||||
|
### Assistant (小助手)
|
||||||
|
|
||||||
|
| 字段 | 类型 | 说明 |
|
||||||
|------|------|------|
|
|------|------|------|
|
||||||
| `/api/assistants` | GET | 助手列表 |
|
| id | string | 助手 ID |
|
||||||
| `/api/assistants` | POST | 创建助手 |
|
| name | string | 助手名称 |
|
||||||
| `/api/assistants/{id}` | GET | 助手详情 |
|
| opener | string | 开场白 |
|
||||||
| `/api/assistants/{id}` | PUT | 更新助手 |
|
| prompt | string | 系统提示词 |
|
||||||
| `/api/assistants/{id}` | DELETE | 删除助手 |
|
| knowledgeBaseId | string | 关联知识库 ID |
|
||||||
| `/api/voices` | GET | 声音库列表 |
|
| language | string | 语言: zh/en |
|
||||||
| `/api/history` | GET | 通话历史列表 |
|
| voice | string | 声音 ID |
|
||||||
| `/api/history/{id}` | GET | 通话详情 |
|
| speed | float | 语速 (0.5-2.0) |
|
||||||
| `/api/history/{id}/transcripts` | POST | 添加转写 |
|
| hotwords | array | 热词列表 |
|
||||||
| `/api/history/{id}/audio/{turn}` | GET | 获取音频 |
|
| tools | array | 启用的工具列表 |
|
||||||
|
| llmModelId | string | LLM 模型 ID |
|
||||||
|
| asrModelId | string | ASR 模型 ID |
|
||||||
|
| embeddingModelId | string | Embedding 模型 ID |
|
||||||
|
| rerankModelId | string | Rerank 模型 ID |
|
||||||
|
|
||||||
|
### Voice (声音资源)
|
||||||
|
|
||||||
|
| 字段 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| id | string | 声音 ID |
|
||||||
|
| name | string | 声音名称 |
|
||||||
|
| vendor | string | 厂商: Ali/Volcano/Minimax |
|
||||||
|
| gender | string | 性别: Male/Female |
|
||||||
|
| language | string | 语言: zh/en |
|
||||||
|
| model | string | 厂商模型标识 |
|
||||||
|
| voice_key | string | 厂商 voice_key |
|
||||||
|
| speed | float | 语速 |
|
||||||
|
| gain | int | 增益 (dB) |
|
||||||
|
| pitch | int | 音调 |
|
||||||
|
|
||||||
|
### LLMModel (模型接入)
|
||||||
|
|
||||||
|
| 字段 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| id | string | 模型 ID |
|
||||||
|
| name | string | 模型名称 |
|
||||||
|
| vendor | string | 厂商 |
|
||||||
|
| type | string | 类型: text/embedding/rerank |
|
||||||
|
| base_url | string | API 地址 |
|
||||||
|
| api_key | string | API 密钥 |
|
||||||
|
| model_name | string | 模型名称 |
|
||||||
|
| temperature | float | 温度参数 |
|
||||||
|
|
||||||
|
### ASRModel (语音识别)
|
||||||
|
|
||||||
|
| 字段 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| id | string | 模型 ID |
|
||||||
|
| name | string | 模型名称 |
|
||||||
|
| vendor | string | 厂商 |
|
||||||
|
| language | string | 语言: zh/en/Multi-lingual |
|
||||||
|
| base_url | string | API 地址 |
|
||||||
|
| api_key | string | API 密钥 |
|
||||||
|
| hotwords | array | 热词列表 |
|
||||||
|
|
||||||
|
### CallRecord (通话记录)
|
||||||
|
|
||||||
|
| 字段 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| id | string | 记录 ID |
|
||||||
|
| assistant_id | string | 助手 ID |
|
||||||
|
| source | string | 来源: debug/external |
|
||||||
|
| status | string | 状态: connected/missed/failed |
|
||||||
|
| started_at | string | 开始时间 |
|
||||||
|
| duration_seconds | int | 通话时长 |
|
||||||
|
| summary | string | 通话摘要 |
|
||||||
|
| transcripts | array | 对话转写 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 使用 Docker 启动
|
## 使用 Docker 启动
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd ~/Code/ai-videoassistant-backend
|
cd api
|
||||||
|
|
||||||
# 启动所有服务
|
# 启动所有服务
|
||||||
docker-compose up -d
|
docker-compose up -d
|
||||||
@@ -71,33 +189,144 @@ docker-compose up -d
|
|||||||
docker-compose logs -f backend
|
docker-compose logs -f backend
|
||||||
```
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 目录结构
|
## 目录结构
|
||||||
|
|
||||||
```
|
```
|
||||||
backend/
|
api/
|
||||||
├── app/
|
├── app/
|
||||||
│ ├── __init__.py
|
│ ├── __init__.py
|
||||||
│ ├── main.py # FastAPI 入口
|
│ ├── main.py # FastAPI 入口
|
||||||
│ ├── db.py # SQLite 连接
|
│ ├── db.py # SQLite 连接
|
||||||
│ ├── models.py # 数据模型
|
│ ├── models.py # SQLAlchemy 数据模型
|
||||||
│ ├── schemas.py # Pydantic 模型
|
│ ├── schemas.py # Pydantic 模型
|
||||||
│ ├── storage.py # MinIO 存储
|
│ ├── storage.py # MinIO 存储
|
||||||
|
│ ├── vector_store.py # 向量存储
|
||||||
│ └── routers/
|
│ └── routers/
|
||||||
│ ├── __init__.py
|
│ ├── __init__.py
|
||||||
│ ├── assistants.py # 助手 API
|
│ ├── assistants.py # 助手 API
|
||||||
│ └── history.py # 通话记录 API
|
│ ├── history.py # 通话记录 API
|
||||||
|
│ └── knowledge.py # 知识库 API
|
||||||
├── data/ # 数据库文件
|
├── data/ # 数据库文件
|
||||||
|
├── docs/ # API 文档
|
||||||
├── requirements.txt
|
├── requirements.txt
|
||||||
├── .env
|
├── .env
|
||||||
|
├── init_db.py
|
||||||
|
├── main.py
|
||||||
└── docker-compose.yml
|
└── docker-compose.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 环境变量
|
## 环境变量
|
||||||
|
|
||||||
| 变量 | 默认值 | 说明 |
|
| 变量 | 默认值 | 说明 |
|
||||||
|------|--------|------|
|
|------|--------|------|
|
||||||
|
| `PORT` | `8100` | 服务端口 |
|
||||||
| `DATABASE_URL` | `sqlite:///./data/app.db` | 数据库连接 |
|
| `DATABASE_URL` | `sqlite:///./data/app.db` | 数据库连接 |
|
||||||
| `MINIO_ENDPOINT` | `localhost:9000` | MinIO 地址 |
|
| `MINIO_ENDPOINT` | `localhost:9000` | MinIO 地址 |
|
||||||
| `MINIO_ACCESS_KEY` | `admin` | MinIO 密钥 |
|
| `MINIO_ACCESS_KEY` | `admin` | MinIO 密钥 |
|
||||||
| `MINIO_SECRET_KEY` | `password123` | MinIO 密码 |
|
| `MINIO_SECRET_KEY` | `password123` | MinIO 密码 |
|
||||||
| `MINIO_BUCKET` | `ai-audio` | 存储桶名称 |
|
| `MINIO_BUCKET` | `ai-audio` | 存储桶名称 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 数据库迁移
|
||||||
|
|
||||||
|
开发环境重新创建数据库:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rm -f api/data/app.db
|
||||||
|
python api/init_db.py
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 测试
|
||||||
|
|
||||||
|
### 安装测试依赖
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd api
|
||||||
|
pip install pytest pytest-cov -q
|
||||||
|
```
|
||||||
|
|
||||||
|
### 运行所有测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Windows
|
||||||
|
run_tests.bat
|
||||||
|
|
||||||
|
# 或使用 pytest
|
||||||
|
pytest tests/ -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### 运行特定测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 只测试声音 API
|
||||||
|
pytest tests/test_voices.py -v
|
||||||
|
|
||||||
|
# 只测试助手 API
|
||||||
|
pytest tests/test_assistants.py -v
|
||||||
|
|
||||||
|
# 只测试历史记录 API
|
||||||
|
pytest tests/test_history.py -v
|
||||||
|
|
||||||
|
# 只测试知识库 API
|
||||||
|
pytest tests/test_knowledge.py -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### 测试覆盖率
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest tests/ --cov=app --cov-report=html
|
||||||
|
# 查看报告: open htmlcov/index.html
|
||||||
|
```
|
||||||
|
|
||||||
|
### 测试目录结构
|
||||||
|
|
||||||
|
```
|
||||||
|
tests/
|
||||||
|
├── __init__.py
|
||||||
|
├── conftest.py # pytest fixtures
|
||||||
|
├── test_voices.py # 声音 API 测试
|
||||||
|
├── test_assistants.py # 助手 API 测试
|
||||||
|
├── test_history.py # 历史记录 API 测试
|
||||||
|
└── test_knowledge.py # 知识库 API 测试
|
||||||
|
```
|
||||||
|
|
||||||
|
### 测试用例统计
|
||||||
|
|
||||||
|
| 模块 | 测试用例数 |
|
||||||
|
|------|-----------|
|
||||||
|
| Voice | 13 |
|
||||||
|
| Assistant | 14 |
|
||||||
|
| History | 18 |
|
||||||
|
| Knowledge | 19 |
|
||||||
|
| **总计** | **64** |
|
||||||
|
|
||||||
|
### CI/CD 示例 (.github/workflows/test.yml)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
name: Tests
|
||||||
|
|
||||||
|
on: [push, pull_request]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install -r api/requirements.txt
|
||||||
|
pip install pytest pytest-cov
|
||||||
|
- name: Run tests
|
||||||
|
run: pytest api/tests/ -v --cov=app
|
||||||
|
```
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import sessionmaker, DeclarativeBase
|
from sqlalchemy.orm import sessionmaker, DeclarativeBase
|
||||||
|
import os
|
||||||
|
|
||||||
DATABASE_URL = "sqlite:///./data/app.db"
|
# 使用绝对路径
|
||||||
|
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
DATABASE_URL = f"sqlite:///{os.path.join(BASE_DIR, 'data', 'app.db')}"
|
||||||
|
|
||||||
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from contextlib import asynccontextmanager
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from .db import Base, engine
|
from .db import Base, engine
|
||||||
from .routers import assistants, history
|
from .routers import assistants, history, knowledge, llm, asr, tools
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -33,6 +33,10 @@ app.add_middleware(
|
|||||||
# 路由
|
# 路由
|
||||||
app.include_router(assistants.router, prefix="/api")
|
app.include_router(assistants.router, prefix="/api")
|
||||||
app.include_router(history.router, prefix="/api")
|
app.include_router(history.router, prefix="/api")
|
||||||
|
app.include_router(knowledge.router, prefix="/api")
|
||||||
|
app.include_router(llm.router, prefix="/api")
|
||||||
|
app.include_router(asr.router, prefix="/api")
|
||||||
|
app.include_router(tools.router, prefix="/api")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
@@ -43,30 +47,3 @@ def root():
|
|||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
def health():
|
def health():
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
# 初始化默认数据
|
|
||||||
@app.on_event("startup")
|
|
||||||
def init_default_data():
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from .db import SessionLocal
|
|
||||||
from .models import Voice
|
|
||||||
|
|
||||||
db = SessionLocal()
|
|
||||||
try:
|
|
||||||
# 检查是否已有数据
|
|
||||||
if db.query(Voice).count() == 0:
|
|
||||||
# 插入默认声音
|
|
||||||
voices = [
|
|
||||||
Voice(id="v1", name="Xiaoyun", vendor="Ali", gender="Female", language="zh", description="Gentle and professional."),
|
|
||||||
Voice(id="v2", name="Kevin", vendor="Volcano", gender="Male", language="en", description="Deep and authoritative."),
|
|
||||||
Voice(id="v3", name="Abby", vendor="Minimax", gender="Female", language="en", description="Cheerful and lively."),
|
|
||||||
Voice(id="v4", name="Guang", vendor="Ali", gender="Male", language="zh", description="Standard newscast style."),
|
|
||||||
Voice(id="v5", name="Doubao", vendor="Volcano", gender="Female", language="zh", description="Cute and young."),
|
|
||||||
]
|
|
||||||
for v in voices:
|
|
||||||
db.add(v)
|
|
||||||
db.commit()
|
|
||||||
print("✅ 默认声音数据已初始化")
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from sqlalchemy import String, Integer, DateTime, Text, Float, ForeignKey, JSON
|
from sqlalchemy import String, Integer, DateTime, Text, Float, ForeignKey, JSON, Enum
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from .db import Base
|
from .db import Base
|
||||||
@@ -15,18 +15,72 @@ class User(Base):
|
|||||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Voice ============
|
||||||
class Voice(Base):
|
class Voice(Base):
|
||||||
__tablename__ = "voices"
|
__tablename__ = "voices"
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
|
user_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("users.id"), index=True, nullable=True)
|
||||||
name: Mapped[str] = mapped_column(String(128), nullable=False)
|
name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||||
vendor: Mapped[str] = mapped_column(String(64), nullable=False)
|
vendor: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
gender: Mapped[str] = mapped_column(String(32), nullable=False)
|
gender: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
language: Mapped[str] = mapped_column(String(16), nullable=False)
|
language: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||||
description: Mapped[str] = mapped_column(String(255), nullable=False)
|
description: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
voice_params: Mapped[dict] = mapped_column(JSON, default=dict)
|
model: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) # 厂商语音模型标识
|
||||||
|
voice_key: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) # 厂商voice_key
|
||||||
|
speed: Mapped[float] = mapped_column(Float, default=1.0)
|
||||||
|
gain: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
|
pitch: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
|
enabled: Mapped[bool] = mapped_column(default=True)
|
||||||
|
is_system: Mapped[bool] = mapped_column(default=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
|
user = relationship("User", foreign_keys=[user_id])
|
||||||
|
|
||||||
|
|
||||||
|
# ============ LLM Model ============
|
||||||
|
class LLMModel(Base):
|
||||||
|
__tablename__ = "llm_models"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||||
|
vendor: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
type: Mapped[str] = mapped_column(String(32), nullable=False) # text/embedding/rerank
|
||||||
|
base_url: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||||
|
api_key: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||||
|
model_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||||
|
temperature: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||||
|
context_length: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||||
|
enabled: Mapped[bool] = mapped_column(default=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
|
user = relationship("User")
|
||||||
|
|
||||||
|
|
||||||
|
# ============ ASR Model ============
|
||||||
|
class ASRModel(Base):
|
||||||
|
__tablename__ = "asr_models"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||||
|
vendor: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
language: Mapped[str] = mapped_column(String(32), nullable=False) # zh/en/Multi-lingual
|
||||||
|
base_url: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||||
|
api_key: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||||
|
model_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||||
|
hotwords: Mapped[dict] = mapped_column(JSON, default=list)
|
||||||
|
enable_punctuation: Mapped[bool] = mapped_column(default=True)
|
||||||
|
enable_normalization: Mapped[bool] = mapped_column(default=True)
|
||||||
|
enabled: Mapped[bool] = mapped_column(default=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
|
user = relationship("User")
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Assistant ============
|
||||||
class Assistant(Base):
|
class Assistant(Base):
|
||||||
__tablename__ = "assistants"
|
__tablename__ = "assistants"
|
||||||
|
|
||||||
@@ -46,6 +100,11 @@ class Assistant(Base):
|
|||||||
config_mode: Mapped[str] = mapped_column(String(32), default="platform")
|
config_mode: Mapped[str] = mapped_column(String(32), default="platform")
|
||||||
api_url: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
api_url: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||||
api_key: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
api_key: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||||
|
# 模型关联
|
||||||
|
llm_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||||
|
asr_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||||
|
embedding_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||||
|
rerank_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
@@ -53,6 +112,7 @@ class Assistant(Base):
|
|||||||
call_records = relationship("CallRecord", back_populates="assistant")
|
call_records = relationship("CallRecord", back_populates="assistant")
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Knowledge Base ============
|
||||||
class KnowledgeBase(Base):
|
class KnowledgeBase(Base):
|
||||||
__tablename__ = "knowledge_bases"
|
__tablename__ = "knowledge_bases"
|
||||||
|
|
||||||
@@ -92,6 +152,7 @@ class KnowledgeDocument(Base):
|
|||||||
kb = relationship("KnowledgeBase", back_populates="documents")
|
kb = relationship("KnowledgeBase", back_populates="documents")
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Workflow ============
|
||||||
class Workflow(Base):
|
class Workflow(Base):
|
||||||
__tablename__ = "workflows"
|
__tablename__ = "workflows"
|
||||||
|
|
||||||
@@ -108,6 +169,7 @@ class Workflow(Base):
|
|||||||
user = relationship("User")
|
user = relationship("User")
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Call Record ============
|
||||||
class CallRecord(Base):
|
class CallRecord(Base):
|
||||||
__tablename__ = "call_records"
|
__tablename__ = "call_records"
|
||||||
|
|
||||||
|
|||||||
@@ -3,9 +3,15 @@ from fastapi import APIRouter
|
|||||||
from . import assistants
|
from . import assistants
|
||||||
from . import history
|
from . import history
|
||||||
from . import knowledge
|
from . import knowledge
|
||||||
|
from . import llm
|
||||||
|
from . import asr
|
||||||
|
from . import tools
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
router.include_router(assistants.router)
|
router.include_router(assistants.router)
|
||||||
router.include_router(history.router)
|
router.include_router(history.router)
|
||||||
router.include_router(knowledge.router)
|
router.include_router(knowledge.router)
|
||||||
|
router.include_router(llm.router)
|
||||||
|
router.include_router(asr.router)
|
||||||
|
router.include_router(tools.router)
|
||||||
|
|||||||
268
api/app/routers/asr.py
Normal file
268
api/app/routers/asr.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List, Optional
|
||||||
|
import uuid
|
||||||
|
import httpx
|
||||||
|
import time
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from ..db import get_db
|
||||||
|
from ..models import ASRModel
|
||||||
|
from ..schemas import (
|
||||||
|
ASRModelCreate, ASRModelUpdate, ASRModelOut,
|
||||||
|
ASRTestRequest, ASRTestResponse, ListResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/asr", tags=["ASR Models"])
|
||||||
|
|
||||||
|
|
||||||
|
# ============ ASR Models CRUD ============
|
||||||
|
@router.get("", response_model=ListResponse)
|
||||||
|
def list_asr_models(
|
||||||
|
language: Optional[str] = None,
|
||||||
|
enabled: Optional[bool] = None,
|
||||||
|
page: int = 1,
|
||||||
|
limit: int = 50,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""获取ASR模型列表"""
|
||||||
|
query = db.query(ASRModel)
|
||||||
|
|
||||||
|
if language:
|
||||||
|
query = query.filter(ASRModel.language == language)
|
||||||
|
if enabled is not None:
|
||||||
|
query = query.filter(ASRModel.enabled == enabled)
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
models = query.order_by(ASRModel.created_at.desc()) \
|
||||||
|
.offset((page-1)*limit).limit(limit).all()
|
||||||
|
|
||||||
|
return {"total": total, "page": page, "limit": limit, "list": models}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{id}", response_model=ASRModelOut)
|
||||||
|
def get_asr_model(id: str, db: Session = Depends(get_db)):
|
||||||
|
"""获取单个ASR模型详情"""
|
||||||
|
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=ASRModelOut)
|
||||||
|
def create_asr_model(data: ASRModelCreate, db: Session = Depends(get_db)):
|
||||||
|
"""创建ASR模型"""
|
||||||
|
asr_model = ASRModel(
|
||||||
|
id=data.id or str(uuid.uuid4())[:8],
|
||||||
|
user_id=1, # 默认用户
|
||||||
|
name=data.name,
|
||||||
|
vendor=data.vendor,
|
||||||
|
language=data.language,
|
||||||
|
base_url=data.base_url,
|
||||||
|
api_key=data.api_key,
|
||||||
|
model_name=data.model_name,
|
||||||
|
hotwords=data.hotwords,
|
||||||
|
enable_punctuation=data.enable_punctuation,
|
||||||
|
enable_normalization=data.enable_normalization,
|
||||||
|
enabled=data.enabled,
|
||||||
|
)
|
||||||
|
db.add(asr_model)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(asr_model)
|
||||||
|
return asr_model
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{id}", response_model=ASRModelOut)
|
||||||
|
def update_asr_model(id: str, data: ASRModelUpdate, db: Session = Depends(get_db)):
|
||||||
|
"""更新ASR模型"""
|
||||||
|
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||||
|
|
||||||
|
update_data = data.model_dump(exclude_unset=True)
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(model, field, value)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{id}")
|
||||||
|
def delete_asr_model(id: str, db: Session = Depends(get_db)):
|
||||||
|
"""删除ASR模型"""
|
||||||
|
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||||
|
db.delete(model)
|
||||||
|
db.commit()
|
||||||
|
return {"message": "Deleted successfully"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{id}/test", response_model=ASRTestResponse)
|
||||||
|
def test_asr_model(
|
||||||
|
id: str,
|
||||||
|
request: Optional[ASRTestRequest] = None,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""测试ASR模型"""
|
||||||
|
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 根据不同的厂商构造不同的请求
|
||||||
|
if model.vendor.lower() in ["siliconflow", "paraformer"]:
|
||||||
|
# SiliconFlow/Paraformer 格式
|
||||||
|
payload = {
|
||||||
|
"model": model.model_name or "paraformer-v2",
|
||||||
|
"input": {},
|
||||||
|
"parameters": {
|
||||||
|
"hotwords": " ".join(model.hotwords) if model.hotwords else "",
|
||||||
|
"enable_punctuation": model.enable_punctuation,
|
||||||
|
"enable_normalization": model.enable_normalization,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果有音频数据
|
||||||
|
if request and request.audio_data:
|
||||||
|
payload["input"]["file_urls"] = []
|
||||||
|
elif request and request.audio_url:
|
||||||
|
payload["input"]["url"] = request.audio_url
|
||||||
|
|
||||||
|
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||||
|
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
|
response = client.post(
|
||||||
|
f"{model.base_url}/asr",
|
||||||
|
json=payload,
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
elif model.vendor.lower() == "openai":
|
||||||
|
# OpenAI Whisper 格式
|
||||||
|
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||||
|
|
||||||
|
# 准备文件
|
||||||
|
files = {}
|
||||||
|
if request and request.audio_data:
|
||||||
|
audio_bytes = base64.b64decode(request.audio_data)
|
||||||
|
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
|
||||||
|
data = {"model": model.model_name or "whisper-1"}
|
||||||
|
elif request and request.audio_url:
|
||||||
|
files = {"file": ("audio.wav", httpx.get(request.audio_url).content, "audio/wav")}
|
||||||
|
data = {"model": model.model_name or "whisper-1"}
|
||||||
|
else:
|
||||||
|
return ASRTestResponse(
|
||||||
|
success=False,
|
||||||
|
error="No audio data or URL provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
|
response = client.post(
|
||||||
|
f"{model.base_url}/audio/transcriptions",
|
||||||
|
files=files,
|
||||||
|
data=data,
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
result = {"results": [{"transcript": result.get("text", "")}]}
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 通用格式(可根据需要扩展)
|
||||||
|
return ASRTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"Unsupported vendor: {model.vendor}"
|
||||||
|
)
|
||||||
|
|
||||||
|
latency_ms = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
# 解析结果
|
||||||
|
if result_data := result.get("results", [{}])[0]:
|
||||||
|
transcript = result_data.get("transcript", "")
|
||||||
|
return ASRTestResponse(
|
||||||
|
success=True,
|
||||||
|
transcript=transcript,
|
||||||
|
language=result_data.get("language", model.language),
|
||||||
|
confidence=result_data.get("confidence"),
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ASRTestResponse(
|
||||||
|
success=False,
|
||||||
|
message="No transcript in response",
|
||||||
|
latency_ms=latency_ms
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
return ASRTestResponse(
|
||||||
|
success=False,
|
||||||
|
error=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return ASRTestResponse(
|
||||||
|
success=False,
|
||||||
|
error=str(e)[:200]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{id}/transcribe")
|
||||||
|
def transcribe_audio(
|
||||||
|
id: str,
|
||||||
|
audio_url: Optional[str] = None,
|
||||||
|
audio_data: Optional[str] = None,
|
||||||
|
hotwords: Optional[List[str]] = None,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""转写音频"""
|
||||||
|
model = db.query(ASRModel).filter(ASRModel.id == id).first()
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=404, detail="ASR Model not found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = {
|
||||||
|
"model": model.model_name or "paraformer-v2",
|
||||||
|
"input": {},
|
||||||
|
"parameters": {
|
||||||
|
"hotwords": " ".join(hotwords or model.hotwords or []),
|
||||||
|
"enable_punctuation": model.enable_punctuation,
|
||||||
|
"enable_normalization": model.enable_normalization,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||||
|
|
||||||
|
if audio_url:
|
||||||
|
payload["input"]["url"] = audio_url
|
||||||
|
elif audio_data:
|
||||||
|
payload["input"]["file_urls"] = []
|
||||||
|
|
||||||
|
with httpx.Client(timeout=120.0) as client:
|
||||||
|
response = client.post(
|
||||||
|
f"{model.base_url}/asr",
|
||||||
|
json=payload,
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
if result_data := result.get("results", [{}])[0]:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"transcript": result_data.get("transcript", ""),
|
||||||
|
"language": result_data.get("language", model.language),
|
||||||
|
"confidence": result_data.get("confidence"),
|
||||||
|
}
|
||||||
|
|
||||||
|
return {"success": False, "error": "No transcript in response"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -8,7 +8,7 @@ from ..db import get_db
|
|||||||
from ..models import Assistant, Voice, Workflow
|
from ..models import Assistant, Voice, Workflow
|
||||||
from ..schemas import (
|
from ..schemas import (
|
||||||
AssistantCreate, AssistantUpdate, AssistantOut,
|
AssistantCreate, AssistantUpdate, AssistantOut,
|
||||||
VoiceOut,
|
VoiceCreate, VoiceUpdate, VoiceOut,
|
||||||
WorkflowCreate, WorkflowUpdate, WorkflowOut
|
WorkflowCreate, WorkflowUpdate, WorkflowOut
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,11 +16,88 @@ router = APIRouter()
|
|||||||
|
|
||||||
|
|
||||||
# ============ Voices ============
|
# ============ Voices ============
|
||||||
@router.get("/voices", response_model=List[VoiceOut])
|
@router.get("/voices")
|
||||||
def list_voices(db: Session = Depends(get_db)):
|
def list_voices(
|
||||||
|
vendor: Optional[str] = None,
|
||||||
|
language: Optional[str] = None,
|
||||||
|
gender: Optional[str] = None,
|
||||||
|
page: int = 1,
|
||||||
|
limit: int = 50,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
"""获取声音库列表"""
|
"""获取声音库列表"""
|
||||||
voices = db.query(Voice).all()
|
query = db.query(Voice)
|
||||||
return voices
|
if vendor:
|
||||||
|
query = query.filter(Voice.vendor == vendor)
|
||||||
|
if language:
|
||||||
|
query = query.filter(Voice.language == language)
|
||||||
|
if gender:
|
||||||
|
query = query.filter(Voice.gender == gender)
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
voices = query.order_by(Voice.created_at.desc()) \
|
||||||
|
.offset((page-1)*limit).limit(limit).all()
|
||||||
|
return {"total": total, "page": page, "limit": limit, "list": voices}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/voices", response_model=VoiceOut)
|
||||||
|
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
|
||||||
|
"""创建声音"""
|
||||||
|
voice = Voice(
|
||||||
|
id=data.id or str(uuid.uuid4())[:8],
|
||||||
|
user_id=1,
|
||||||
|
name=data.name,
|
||||||
|
vendor=data.vendor,
|
||||||
|
gender=data.gender,
|
||||||
|
language=data.language,
|
||||||
|
description=data.description,
|
||||||
|
model=data.model,
|
||||||
|
voice_key=data.voice_key,
|
||||||
|
speed=data.speed,
|
||||||
|
gain=data.gain,
|
||||||
|
pitch=data.pitch,
|
||||||
|
enabled=data.enabled,
|
||||||
|
)
|
||||||
|
db.add(voice)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(voice)
|
||||||
|
return voice
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/voices/{id}", response_model=VoiceOut)
|
||||||
|
def get_voice(id: str, db: Session = Depends(get_db)):
|
||||||
|
"""获取单个声音详情"""
|
||||||
|
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||||
|
if not voice:
|
||||||
|
raise HTTPException(status_code=404, detail="Voice not found")
|
||||||
|
return voice
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/voices/{id}", response_model=VoiceOut)
|
||||||
|
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
|
||||||
|
"""更新声音"""
|
||||||
|
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||||
|
if not voice:
|
||||||
|
raise HTTPException(status_code=404, detail="Voice not found")
|
||||||
|
|
||||||
|
update_data = data.model_dump(exclude_unset=True)
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(voice, field, value)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(voice)
|
||||||
|
return voice
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/voices/{id}")
|
||||||
|
def delete_voice(id: str, db: Session = Depends(get_db)):
|
||||||
|
"""删除声音"""
|
||||||
|
voice = db.query(Voice).filter(Voice.id == id).first()
|
||||||
|
if not voice:
|
||||||
|
raise HTTPException(status_code=404, detail="Voice not found")
|
||||||
|
db.delete(voice)
|
||||||
|
db.commit()
|
||||||
|
return {"message": "Deleted successfully"}
|
||||||
|
|
||||||
|
|
||||||
# ============ Assistants ============
|
# ============ Assistants ============
|
||||||
@@ -103,10 +180,17 @@ def delete_assistant(id: str, db: Session = Depends(get_db)):
|
|||||||
|
|
||||||
# ============ Workflows ============
|
# ============ Workflows ============
|
||||||
@router.get("/workflows", response_model=List[WorkflowOut])
|
@router.get("/workflows", response_model=List[WorkflowOut])
|
||||||
def list_workflows(db: Session = Depends(get_db)):
|
def list_workflows(
|
||||||
|
page: int = 1,
|
||||||
|
limit: int = 50,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
"""获取工作流列表"""
|
"""获取工作流列表"""
|
||||||
workflows = db.query(Workflow).all()
|
query = db.query(Workflow)
|
||||||
return workflows
|
total = query.count()
|
||||||
|
workflows = query.order_by(Workflow.created_at.desc()) \
|
||||||
|
.offset((page-1)*limit).limit(limit).all()
|
||||||
|
return {"total": total, "page": page, "limit": limit, "list": workflows}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/workflows", response_model=WorkflowOut)
|
@router.post("/workflows", response_model=WorkflowOut)
|
||||||
@@ -129,6 +213,15 @@ def create_workflow(data: WorkflowCreate, db: Session = Depends(get_db)):
|
|||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/workflows/{id}", response_model=WorkflowOut)
|
||||||
|
def get_workflow(id: str, db: Session = Depends(get_db)):
|
||||||
|
"""获取单个工作流"""
|
||||||
|
workflow = db.query(Workflow).filter(Workflow.id == id).first()
|
||||||
|
if not workflow:
|
||||||
|
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||||
|
return workflow
|
||||||
|
|
||||||
|
|
||||||
@router.put("/workflows/{id}", response_model=WorkflowOut)
|
@router.put("/workflows/{id}", response_model=WorkflowOut)
|
||||||
def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)):
|
def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)):
|
||||||
"""更新工作流"""
|
"""更新工作流"""
|
||||||
|
|||||||
206
api/app/routers/llm.py
Normal file
206
api/app/routers/llm.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List, Optional
|
||||||
|
import uuid
|
||||||
|
import httpx
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from ..db import get_db
|
||||||
|
from ..models import LLMModel
|
||||||
|
from ..schemas import (
|
||||||
|
LLMModelCreate, LLMModelUpdate, LLMModelOut,
|
||||||
|
LLMModelTestResponse, ListResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/llm", tags=["LLM Models"])
|
||||||
|
|
||||||
|
|
||||||
|
# ============ LLM Models CRUD ============
|
||||||
|
@router.get("", response_model=ListResponse)
|
||||||
|
def list_llm_models(
|
||||||
|
model_type: Optional[str] = None,
|
||||||
|
enabled: Optional[bool] = None,
|
||||||
|
page: int = 1,
|
||||||
|
limit: int = 50,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""获取LLM模型列表"""
|
||||||
|
query = db.query(LLMModel)
|
||||||
|
|
||||||
|
if model_type:
|
||||||
|
query = query.filter(LLMModel.type == model_type)
|
||||||
|
if enabled is not None:
|
||||||
|
query = query.filter(LLMModel.enabled == enabled)
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
models = query.order_by(LLMModel.created_at.desc()) \
|
||||||
|
.offset((page-1)*limit).limit(limit).all()
|
||||||
|
|
||||||
|
return {"total": total, "page": page, "limit": limit, "list": models}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{id}", response_model=LLMModelOut)
|
||||||
|
def get_llm_model(id: str, db: Session = Depends(get_db)):
|
||||||
|
"""获取单个LLM模型详情"""
|
||||||
|
model = db.query(LLMModel).filter(LLMModel.id == id).first()
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=404, detail="LLM Model not found")
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=LLMModelOut)
|
||||||
|
def create_llm_model(data: LLMModelCreate, db: Session = Depends(get_db)):
|
||||||
|
"""创建LLM模型"""
|
||||||
|
llm_model = LLMModel(
|
||||||
|
id=data.id or str(uuid.uuid4())[:8],
|
||||||
|
user_id=1, # 默认用户
|
||||||
|
name=data.name,
|
||||||
|
vendor=data.vendor,
|
||||||
|
type=data.type.value if hasattr(data.type, 'value') else data.type,
|
||||||
|
base_url=data.base_url,
|
||||||
|
api_key=data.api_key,
|
||||||
|
model_name=data.model_name,
|
||||||
|
temperature=data.temperature,
|
||||||
|
context_length=data.context_length,
|
||||||
|
enabled=data.enabled,
|
||||||
|
)
|
||||||
|
db.add(llm_model)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(llm_model)
|
||||||
|
return llm_model
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{id}", response_model=LLMModelOut)
|
||||||
|
def update_llm_model(id: str, data: LLMModelUpdate, db: Session = Depends(get_db)):
|
||||||
|
"""更新LLM模型"""
|
||||||
|
model = db.query(LLMModel).filter(LLMModel.id == id).first()
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=404, detail="LLM Model not found")
|
||||||
|
|
||||||
|
update_data = data.model_dump(exclude_unset=True)
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(model, field, value)
|
||||||
|
|
||||||
|
model.updated_at = datetime.utcnow()
|
||||||
|
db.commit()
|
||||||
|
db.refresh(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{id}")
|
||||||
|
def delete_llm_model(id: str, db: Session = Depends(get_db)):
|
||||||
|
"""删除LLM模型"""
|
||||||
|
model = db.query(LLMModel).filter(LLMModel.id == id).first()
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=404, detail="LLM Model not found")
|
||||||
|
db.delete(model)
|
||||||
|
db.commit()
|
||||||
|
return {"message": "Deleted successfully"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{id}/test", response_model=LLMModelTestResponse)
|
||||||
|
def test_llm_model(id: str, db: Session = Depends(get_db)):
|
||||||
|
"""测试LLM模型连接"""
|
||||||
|
model = db.query(LLMModel).filter(LLMModel.id == id).first()
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=404, detail="LLM Model not found")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
# 构造测试请求
|
||||||
|
test_messages = [{"role": "user", "content": "Hello, please reply with 'OK'."}]
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": model.model_name or "gpt-3.5-turbo",
|
||||||
|
"messages": test_messages,
|
||||||
|
"max_tokens": 10,
|
||||||
|
"temperature": 0.1,
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||||
|
|
||||||
|
with httpx.Client(timeout=30.0) as client:
|
||||||
|
response = client.post(
|
||||||
|
f"{model.base_url}/chat/completions",
|
||||||
|
json=payload,
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
latency_ms = int((time.time() - start_time) * 1000)
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
if result.get("choices"):
|
||||||
|
return LLMModelTestResponse(
|
||||||
|
success=True,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
message="Connection successful"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return LLMModelTestResponse(
|
||||||
|
success=False,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
message="Unexpected response format"
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
return LLMModelTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"HTTP Error: {e.response.status_code} - {e.response.text[:200]}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return LLMModelTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=str(e)[:200]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{id}/chat")
|
||||||
|
def chat_with_llm(
|
||||||
|
id: str,
|
||||||
|
message: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""与LLM模型对话"""
|
||||||
|
model = db.query(LLMModel).filter(LLMModel.id == id).first()
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=404, detail="LLM Model not found")
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
messages.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": model.model_name or "gpt-3.5-turbo",
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": max_tokens or 1000,
|
||||||
|
"temperature": temperature if temperature is not None else model.temperature or 0.7,
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
|
response = client.post(
|
||||||
|
f"{model.base_url}/chat/completions",
|
||||||
|
json=payload,
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
if choice := result.get("choices", [{}])[0]:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"reply": choice.get("message", {}).get("content", ""),
|
||||||
|
"usage": result.get("usage", {})
|
||||||
|
}
|
||||||
|
return {"success": False, "reply": "", "error": "No response"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
379
api/app/routers/tools.py
Normal file
379
api/app/routers/tools.py
Normal file
@@ -0,0 +1,379 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from ..db import get_db
|
||||||
|
from ..models import LLMModel, ASRModel
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/tools", tags=["Tools & Autotest"])
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Available Tools ============
|
||||||
|
TOOL_REGISTRY = {
|
||||||
|
"search": {
|
||||||
|
"name": "网络搜索",
|
||||||
|
"description": "搜索互联网获取最新信息",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string", "description": "搜索关键词"}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"calculator": {
|
||||||
|
"name": "计算器",
|
||||||
|
"description": "执行数学计算",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}
|
||||||
|
},
|
||||||
|
"required": ["expression"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"weather": {
|
||||||
|
"name": "天气查询",
|
||||||
|
"description": "查询指定城市的天气",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {"type": "string", "description": "城市名称"}
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"translate": {
|
||||||
|
"name": "翻译",
|
||||||
|
"description": "翻译文本到指定语言",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {"type": "string", "description": "要翻译的文本"},
|
||||||
|
"target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"}
|
||||||
|
},
|
||||||
|
"required": ["text", "target_lang"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"knowledge": {
|
||||||
|
"name": "知识库查询",
|
||||||
|
"description": "从知识库中检索相关信息",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string", "description": "查询内容"},
|
||||||
|
"kb_id": {"type": "string", "description": "知识库ID"}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"code_interpreter": {
|
||||||
|
"name": "代码执行",
|
||||||
|
"description": "安全地执行Python代码",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"code": {"type": "string", "description": "要执行的Python代码"}
|
||||||
|
},
|
||||||
|
"required": ["code"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/list")
|
||||||
|
def list_available_tools():
|
||||||
|
"""获取可用的工具列表"""
|
||||||
|
return {"tools": TOOL_REGISTRY}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/list/{tool_id}")
|
||||||
|
def get_tool_detail(tool_id: str):
|
||||||
|
"""获取工具详情"""
|
||||||
|
if tool_id not in TOOL_REGISTRY:
|
||||||
|
raise HTTPException(status_code=404, detail="Tool not found")
|
||||||
|
return TOOL_REGISTRY[tool_id]
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Autotest ============
|
||||||
|
class AutotestResult:
|
||||||
|
"""自动测试结果"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.id = str(uuid.uuid4())[:8]
|
||||||
|
self.started_at = time.time()
|
||||||
|
self.tests = []
|
||||||
|
self.summary = {"passed": 0, "failed": 0, "total": 0}
|
||||||
|
|
||||||
|
def add_test(self, name: str, passed: bool, message: str = "", duration_ms: int = 0):
|
||||||
|
self.tests.append({
|
||||||
|
"name": name,
|
||||||
|
"passed": passed,
|
||||||
|
"message": message,
|
||||||
|
"duration_ms": duration_ms
|
||||||
|
})
|
||||||
|
if passed:
|
||||||
|
self.summary["passed"] += 1
|
||||||
|
else:
|
||||||
|
self.summary["failed"] += 1
|
||||||
|
self.summary["total"] += 1
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"started_at": self.started_at,
|
||||||
|
"duration_ms": int((time.time() - self.started_at) * 1000),
|
||||||
|
"tests": self.tests,
|
||||||
|
"summary": self.summary
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/autotest")
|
||||||
|
def run_autotest(
|
||||||
|
llm_model_id: Optional[str] = None,
|
||||||
|
asr_model_id: Optional[str] = None,
|
||||||
|
test_llm: bool = True,
|
||||||
|
test_asr: bool = True,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""运行自动测试"""
|
||||||
|
result = AutotestResult()
|
||||||
|
|
||||||
|
# 测试 LLM 模型
|
||||||
|
if test_llm and llm_model_id:
|
||||||
|
_test_llm_model(db, llm_model_id, result)
|
||||||
|
|
||||||
|
# 测试 ASR 模型
|
||||||
|
if test_asr and asr_model_id:
|
||||||
|
_test_asr_model(db, asr_model_id, result)
|
||||||
|
|
||||||
|
# 测试 TTS 功能(需要时可添加)
|
||||||
|
if test_llm and not llm_model_id:
|
||||||
|
result.add_test(
|
||||||
|
"LLM Model Check",
|
||||||
|
False,
|
||||||
|
"No LLM model ID provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
if test_asr and not asr_model_id:
|
||||||
|
result.add_test(
|
||||||
|
"ASR Model Check",
|
||||||
|
False,
|
||||||
|
"No ASR model ID provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/autotest/llm/{model_id}")
|
||||||
|
def autotest_llm_model(model_id: str, db: Session = Depends(get_db)):
|
||||||
|
"""测试单个LLM模型"""
|
||||||
|
result = AutotestResult()
|
||||||
|
_test_llm_model(db, model_id, result)
|
||||||
|
return result.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/autotest/asr/{model_id}")
|
||||||
|
def autotest_asr_model(model_id: str, db: Session = Depends(get_db)):
|
||||||
|
"""测试单个ASR模型"""
|
||||||
|
result = AutotestResult()
|
||||||
|
_test_asr_model(db, model_id, result)
|
||||||
|
return result.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
def _test_llm_model(db: Session, model_id: str, result: AutotestResult):
|
||||||
|
"""内部方法:测试LLM模型"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 1. 检查模型是否存在
|
||||||
|
model = db.query(LLMModel).filter(LLMModel.id == model_id).first()
|
||||||
|
duration_ms = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
|
||||||
|
return
|
||||||
|
|
||||||
|
result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)
|
||||||
|
|
||||||
|
# 2. 测试连接
|
||||||
|
test_start = time.time()
|
||||||
|
try:
|
||||||
|
test_messages = [{"role": "user", "content": "Reply with 'OK'."}]
|
||||||
|
payload = {
|
||||||
|
"model": model.model_name or "gpt-3.5-turbo",
|
||||||
|
"messages": test_messages,
|
||||||
|
"max_tokens": 10,
|
||||||
|
"temperature": 0.1,
|
||||||
|
}
|
||||||
|
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||||
|
|
||||||
|
with httpx.Client(timeout=30.0) as client:
|
||||||
|
response = client.post(
|
||||||
|
f"{model.base_url}/chat/completions",
|
||||||
|
json=payload,
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result_text = response.json()
|
||||||
|
latency_ms = int((time.time() - test_start) * 1000)
|
||||||
|
|
||||||
|
if result_text.get("choices"):
|
||||||
|
result.add_test("API Connection", True, f"Latency: {latency_ms}ms", latency_ms)
|
||||||
|
else:
|
||||||
|
result.add_test("API Connection", False, "Empty response", latency_ms)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
latency_ms = int((time.time() - test_start) * 1000)
|
||||||
|
result.add_test("API Connection", False, str(e)[:200], latency_ms)
|
||||||
|
|
||||||
|
# 3. 检查模型配置
|
||||||
|
if model.temperature is not None:
|
||||||
|
result.add_test("Temperature Setting", True, f"temperature={model.temperature}")
|
||||||
|
else:
|
||||||
|
result.add_test("Temperature Setting", True, "Using default")
|
||||||
|
|
||||||
|
# 4. 测试流式响应(可选)
|
||||||
|
if model.type == "text":
|
||||||
|
test_start = time.time()
|
||||||
|
try:
|
||||||
|
with httpx.Client(timeout=30.0) as client:
|
||||||
|
with client.stream(
|
||||||
|
"POST",
|
||||||
|
f"{model.base_url}/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": model.model_name or "gpt-3.5-turbo",
|
||||||
|
"messages": [{"role": "user", "content": "Count from 1 to 3."}],
|
||||||
|
"stream": True,
|
||||||
|
},
|
||||||
|
headers=headers
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
chunk_count = 0
|
||||||
|
for _ in response.iter_bytes():
|
||||||
|
chunk_count += 1
|
||||||
|
|
||||||
|
latency_ms = int((time.time() - test_start) * 1000)
|
||||||
|
result.add_test("Streaming Support", True, f"Received {chunk_count} chunks", latency_ms)
|
||||||
|
except Exception as e:
|
||||||
|
latency_ms = int((time.time() - test_start) * 1000)
|
||||||
|
result.add_test("Streaming Support", False, str(e)[:200], latency_ms)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_asr_model(db: Session, model_id: str, result: AutotestResult):
|
||||||
|
"""内部方法:测试ASR模型"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 1. 检查模型是否存在
|
||||||
|
model = db.query(ASRModel).filter(ASRModel.id == model_id).first()
|
||||||
|
duration_ms = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
result.add_test("Model Existence", False, f"Model {model_id} not found", duration_ms)
|
||||||
|
return
|
||||||
|
|
||||||
|
result.add_test("Model Existence", True, f"Found model: {model.name}", duration_ms)
|
||||||
|
|
||||||
|
# 2. 测试配置
|
||||||
|
if model.hotwords:
|
||||||
|
result.add_test("Hotwords Config", True, f"Hotwords: {len(model.hotwords)} words")
|
||||||
|
else:
|
||||||
|
result.add_test("Hotwords Config", True, "No hotwords configured")
|
||||||
|
|
||||||
|
# 3. 测试API可用性
|
||||||
|
test_start = time.time()
|
||||||
|
try:
|
||||||
|
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||||
|
|
||||||
|
with httpx.Client(timeout=30.0) as client:
|
||||||
|
if model.vendor.lower() in ["siliconflow", "paraformer"]:
|
||||||
|
response = client.get(
|
||||||
|
f"{model.base_url}/asr",
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
elif model.vendor.lower() == "openai":
|
||||||
|
response = client.get(
|
||||||
|
f"{model.base_url}/audio/models",
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 通用健康检查
|
||||||
|
response = client.get(
|
||||||
|
f"{model.base_url}/health",
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
latency_ms = int((time.time() - test_start) * 1000)
|
||||||
|
|
||||||
|
if response.status_code in [200, 405]: # 405 = method not allowed but endpoint exists
|
||||||
|
result.add_test("API Availability", True, f"Status: {response.status_code}", latency_ms)
|
||||||
|
else:
|
||||||
|
result.add_test("API Availability", False, f"Status: {response.status_code}", latency_ms)
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
latency_ms = int((time.time() - test_start) * 1000)
|
||||||
|
result.add_test("API Availability", False, "Connection timeout", latency_ms)
|
||||||
|
except Exception as e:
|
||||||
|
latency_ms = int((time.time() - test_start) * 1000)
|
||||||
|
result.add_test("API Availability", False, str(e)[:200], latency_ms)
|
||||||
|
|
||||||
|
# 4. 检查语言配置
|
||||||
|
if model.language in ["zh", "en", "Multi-lingual"]:
|
||||||
|
result.add_test("Language Config", True, f"Language: {model.language}")
|
||||||
|
else:
|
||||||
|
result.add_test("Language Config", False, f"Unknown language: {model.language}")
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Quick Health Check ============
|
||||||
|
@router.get("/health")
|
||||||
|
def health_check():
|
||||||
|
"""快速健康检查"""
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"tools": list(TOOL_REGISTRY.keys())
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/test-message")
|
||||||
|
def send_test_message(
|
||||||
|
llm_model_id: str,
|
||||||
|
message: str = "Hello, this is a test message.",
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""发送测试消息"""
|
||||||
|
model = db.query(LLMModel).filter(LLMModel.id == llm_model_id).first()
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=404, detail="LLM Model not found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = {
|
||||||
|
"model": model.model_name or "gpt-3.5-turbo",
|
||||||
|
"messages": [{"role": "user", "content": message}],
|
||||||
|
"max_tokens": 500,
|
||||||
|
"temperature": 0.7,
|
||||||
|
}
|
||||||
|
headers = {"Authorization": f"Bearer {model.api_key}"}
|
||||||
|
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
|
response = client.post(
|
||||||
|
f"{model.base_url}/chat/completions",
|
||||||
|
json=payload,
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
reply = result.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"reply": reply,
|
||||||
|
"usage": result.get("usage", {})
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
@@ -1,24 +1,203 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Enums ============
|
||||||
|
class AssistantConfigMode(str, Enum):
|
||||||
|
PLATFORM = "platform"
|
||||||
|
DIFY = "dify"
|
||||||
|
FASTGPT = "fastgpt"
|
||||||
|
NONE = "none"
|
||||||
|
|
||||||
|
|
||||||
|
class LLMModelType(str, Enum):
|
||||||
|
TEXT = "text"
|
||||||
|
EMBEDDING = "embedding"
|
||||||
|
RERANK = "rerank"
|
||||||
|
|
||||||
|
|
||||||
|
class ASRLanguage(str, Enum):
|
||||||
|
ZH = "zh"
|
||||||
|
EN = "en"
|
||||||
|
MULTILINGUAL = "Multi-lingual"
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceGender(str, Enum):
|
||||||
|
MALE = "Male"
|
||||||
|
FEMALE = "Female"
|
||||||
|
|
||||||
|
|
||||||
|
class CallRecordSource(str, Enum):
|
||||||
|
DEBUG = "debug"
|
||||||
|
EXTERNAL = "external"
|
||||||
|
|
||||||
|
|
||||||
|
class CallRecordStatus(str, Enum):
|
||||||
|
CONNECTED = "connected"
|
||||||
|
MISSED = "missed"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
# ============ Voice ============
|
# ============ Voice ============
|
||||||
class VoiceBase(BaseModel):
|
class VoiceBase(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
vendor: str
|
vendor: str
|
||||||
gender: str
|
gender: str # "Male" | "Female"
|
||||||
language: str
|
language: str # "zh" | "en"
|
||||||
description: str
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceCreate(VoiceBase):
|
||||||
|
model: str # 厂商语音模型标识
|
||||||
|
voice_key: str # 厂商voice_key
|
||||||
|
speed: float = 1.0
|
||||||
|
gain: int = 0
|
||||||
|
pitch: int = 0
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceUpdate(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
model: Optional[str] = None
|
||||||
|
voice_key: Optional[str] = None
|
||||||
|
speed: Optional[float] = None
|
||||||
|
gain: Optional[int] = None
|
||||||
|
pitch: Optional[int] = None
|
||||||
|
enabled: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
class VoiceOut(VoiceBase):
|
class VoiceOut(VoiceBase):
|
||||||
id: str
|
id: str
|
||||||
|
user_id: Optional[int] = None
|
||||||
|
model: Optional[str] = None
|
||||||
|
voice_key: Optional[str] = None
|
||||||
|
speed: float = 1.0
|
||||||
|
gain: int = 0
|
||||||
|
pitch: int = 0
|
||||||
|
enabled: bool = True
|
||||||
|
is_system: bool = False
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class VoicePreviewRequest(BaseModel):
|
||||||
|
text: str
|
||||||
|
speed: Optional[float] = None
|
||||||
|
gain: Optional[int] = None
|
||||||
|
pitch: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class VoicePreviewResponse(BaseModel):
|
||||||
|
success: bool
|
||||||
|
audio_url: Optional[str] = None
|
||||||
|
duration_ms: Optional[int] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# ============ LLM Model ============
|
||||||
|
class LLMModelBase(BaseModel):
|
||||||
|
name: str
|
||||||
|
vendor: str
|
||||||
|
type: LLMModelType
|
||||||
|
base_url: str
|
||||||
|
api_key: str
|
||||||
|
model_name: Optional[str] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
context_length: Optional[int] = None
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class LLMModelCreate(LLMModelBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LLMModelUpdate(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
model_name: Optional[str] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
context_length: Optional[int] = None
|
||||||
|
enabled: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
class LLMModelOut(LLMModelBase):
|
||||||
|
id: str
|
||||||
|
user_id: int
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class LLMModelTestResponse(BaseModel):
|
||||||
|
success: bool
|
||||||
|
latency_ms: Optional[int] = None
|
||||||
|
message: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# ============ ASR Model ============
|
||||||
|
class ASRModelBase(BaseModel):
|
||||||
|
name: str
|
||||||
|
vendor: str
|
||||||
|
language: str # "zh" | "en" | "Multi-lingual"
|
||||||
|
base_url: str
|
||||||
|
api_key: str
|
||||||
|
model_name: Optional[str] = None
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ASRModelCreate(ASRModelBase):
|
||||||
|
hotwords: List[str] = []
|
||||||
|
enable_punctuation: bool = True
|
||||||
|
enable_normalization: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ASRModelUpdate(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
language: Optional[str] = None
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
model_name: Optional[str] = None
|
||||||
|
hotwords: Optional[List[str]] = None
|
||||||
|
enable_punctuation: Optional[bool] = None
|
||||||
|
enable_normalization: Optional[bool] = None
|
||||||
|
enabled: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ASRModelOut(ASRModelBase):
|
||||||
|
id: str
|
||||||
|
user_id: int
|
||||||
|
hotwords: List[str] = []
|
||||||
|
enable_punctuation: bool = True
|
||||||
|
enable_normalization: bool = True
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class ASRTestRequest(BaseModel):
|
||||||
|
audio_url: Optional[str] = None
|
||||||
|
audio_data: Optional[str] = None # base64 encoded
|
||||||
|
|
||||||
|
|
||||||
|
class ASRTestResponse(BaseModel):
|
||||||
|
success: bool
|
||||||
|
transcript: Optional[str] = None
|
||||||
|
language: Optional[str] = None
|
||||||
|
confidence: Optional[float] = None
|
||||||
|
duration_ms: Optional[int] = None
|
||||||
|
latency_ms: Optional[int] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
# ============ Assistant ============
|
# ============ Assistant ============
|
||||||
class AssistantBase(BaseModel):
|
class AssistantBase(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
@@ -34,25 +213,56 @@ class AssistantBase(BaseModel):
|
|||||||
configMode: str = "platform"
|
configMode: str = "platform"
|
||||||
apiUrl: Optional[str] = None
|
apiUrl: Optional[str] = None
|
||||||
apiKey: Optional[str] = None
|
apiKey: Optional[str] = None
|
||||||
|
# 模型关联
|
||||||
|
llmModelId: Optional[str] = None
|
||||||
|
asrModelId: Optional[str] = None
|
||||||
|
embeddingModelId: Optional[str] = None
|
||||||
|
rerankModelId: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class AssistantCreate(AssistantBase):
|
class AssistantCreate(AssistantBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AssistantUpdate(AssistantBase):
|
class AssistantUpdate(BaseModel):
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
|
opener: Optional[str] = None
|
||||||
|
prompt: Optional[str] = None
|
||||||
|
knowledgeBaseId: Optional[str] = None
|
||||||
|
language: Optional[str] = None
|
||||||
|
voice: Optional[str] = None
|
||||||
|
speed: Optional[float] = None
|
||||||
|
hotwords: Optional[List[str]] = None
|
||||||
|
tools: Optional[List[str]] = None
|
||||||
|
interruptionSensitivity: Optional[int] = None
|
||||||
|
configMode: Optional[str] = None
|
||||||
|
apiUrl: Optional[str] = None
|
||||||
|
apiKey: Optional[str] = None
|
||||||
|
llmModelId: Optional[str] = None
|
||||||
|
asrModelId: Optional[str] = None
|
||||||
|
embeddingModelId: Optional[str] = None
|
||||||
|
rerankModelId: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class AssistantOut(AssistantBase):
|
class AssistantOut(AssistantBase):
|
||||||
id: str
|
id: str
|
||||||
callCount: int = 0
|
callCount: int = 0
|
||||||
created_at: Optional[datetime] = None
|
created_at: Optional[datetime] = None
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantStats(BaseModel):
|
||||||
|
assistant_id: str
|
||||||
|
total_calls: int = 0
|
||||||
|
connected_calls: int = 0
|
||||||
|
missed_calls: int = 0
|
||||||
|
avg_duration_seconds: float = 0.0
|
||||||
|
today_calls: int = 0
|
||||||
|
|
||||||
|
|
||||||
# ============ Knowledge Base ============
|
# ============ Knowledge Base ============
|
||||||
class KnowledgeDocument(BaseModel):
|
class KnowledgeDocument(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
@@ -196,6 +406,7 @@ class TranscriptSegment(BaseModel):
|
|||||||
endMs: int
|
endMs: int
|
||||||
durationMs: Optional[int] = None
|
durationMs: Optional[int] = None
|
||||||
audioUrl: Optional[str] = None
|
audioUrl: Optional[str] = None
|
||||||
|
emotion: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class CallRecordCreate(BaseModel):
|
class CallRecordCreate(BaseModel):
|
||||||
@@ -208,6 +419,9 @@ class CallRecordUpdate(BaseModel):
|
|||||||
status: Optional[str] = None
|
status: Optional[str] = None
|
||||||
summary: Optional[str] = None
|
summary: Optional[str] = None
|
||||||
duration_seconds: Optional[int] = None
|
duration_seconds: Optional[int] = None
|
||||||
|
ended_at: Optional[str] = None
|
||||||
|
cost: Optional[float] = None
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class CallRecordOut(BaseModel):
|
class CallRecordOut(BaseModel):
|
||||||
@@ -220,6 +434,9 @@ class CallRecordOut(BaseModel):
|
|||||||
ended_at: Optional[str] = None
|
ended_at: Optional[str] = None
|
||||||
duration_seconds: Optional[int] = None
|
duration_seconds: Optional[int] = None
|
||||||
summary: Optional[str] = None
|
summary: Optional[str] = None
|
||||||
|
cost: float = 0.0
|
||||||
|
metadata: dict = {}
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
transcripts: List[TranscriptSegment] = []
|
transcripts: List[TranscriptSegment] = []
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@@ -246,6 +463,19 @@ class TranscriptOut(TranscriptCreate):
|
|||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
# ============ History Stats ============
|
||||||
|
class HistoryStats(BaseModel):
|
||||||
|
total_calls: int = 0
|
||||||
|
connected_calls: int = 0
|
||||||
|
missed_calls: int = 0
|
||||||
|
failed_calls: int = 0
|
||||||
|
avg_duration_seconds: float = 0.0
|
||||||
|
total_cost: float = 0.0
|
||||||
|
by_status: dict = {}
|
||||||
|
by_source: dict = {}
|
||||||
|
daily_trend: List[dict] = []
|
||||||
|
|
||||||
|
|
||||||
# ============ Dashboard ============
|
# ============ Dashboard ============
|
||||||
class DashboardStats(BaseModel):
|
class DashboardStats(BaseModel):
|
||||||
totalCalls: int
|
totalCalls: int
|
||||||
@@ -269,3 +499,9 @@ class ListResponse(BaseModel):
|
|||||||
page: int
|
page: int
|
||||||
limit: int
|
limit: int
|
||||||
list: List
|
list: List
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResult(BaseModel):
|
||||||
|
id: str
|
||||||
|
started_at: str
|
||||||
|
matched_content: Optional[str] = None
|
||||||
|
|||||||
409
api/docs/asr.md
Normal file
409
api/docs/asr.md
Normal file
@@ -0,0 +1,409 @@
|
|||||||
|
# 语音识别 (ASR Model) API
|
||||||
|
|
||||||
|
语音识别 API 用于管理语音识别模型的配置和调用。
|
||||||
|
|
||||||
|
## 基础信息
|
||||||
|
|
||||||
|
| 项目 | 值 |
|
||||||
|
|------|-----|
|
||||||
|
| Base URL | `/api/v1/asr` |
|
||||||
|
| 认证方式 | Bearer Token (预留) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 数据模型
|
||||||
|
|
||||||
|
### ASRModel
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
interface ASRModel {
|
||||||
|
id: string; // 模型唯一标识 (8位UUID)
|
||||||
|
user_id: number; // 所属用户ID
|
||||||
|
name: string; // 模型显示名称
|
||||||
|
vendor: string; // 供应商: "OpenAI" | "SiliconFlow" | "Paraformer" | 等
|
||||||
|
language: string; // 识别语言: "zh" | "en" | "Multi-lingual"
|
||||||
|
base_url: string; // API Base URL
|
||||||
|
api_key: string; // API Key
|
||||||
|
model_name?: string; // 模型名称,如 "whisper-1" | "paraformer-v2"
|
||||||
|
hotwords?: string[]; // 热词列表
|
||||||
|
enable_punctuation: boolean; // 是否启用标点
|
||||||
|
enable_normalization: boolean; // 是否启用文本规范化
|
||||||
|
enabled: boolean; // 是否启用
|
||||||
|
created_at: string;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API 端点
|
||||||
|
|
||||||
|
### 1. 获取 ASR 模型列表
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /api/v1/asr
|
||||||
|
```
|
||||||
|
|
||||||
|
**Query Parameters:**
|
||||||
|
|
||||||
|
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||||
|
|------|------|------|--------|------|
|
||||||
|
| language | string | 否 | - | 过滤语言: "zh" \| "en" \| "Multi-lingual" |
|
||||||
|
| enabled | boolean | 否 | - | 过滤启用状态 |
|
||||||
|
| page | int | 否 | 1 | 页码 |
|
||||||
|
| limit | int | 否 | 50 | 每页数量 |
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"total": 3,
|
||||||
|
"page": 1,
|
||||||
|
"limit": 50,
|
||||||
|
"list": [
|
||||||
|
{
|
||||||
|
"id": "abc12345",
|
||||||
|
"user_id": 1,
|
||||||
|
"name": "Whisper 多语种识别",
|
||||||
|
"vendor": "OpenAI",
|
||||||
|
"language": "Multi-lingual",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"api_key": "sk-***",
|
||||||
|
"model_name": "whisper-1",
|
||||||
|
"enable_punctuation": true,
|
||||||
|
"enable_normalization": true,
|
||||||
|
"enabled": true,
|
||||||
|
"created_at": "2024-01-15T10:30:00Z"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "def67890",
|
||||||
|
"user_id": 1,
|
||||||
|
"name": "SenseVoice 中文识别",
|
||||||
|
"vendor": "SiliconFlow",
|
||||||
|
"language": "zh",
|
||||||
|
"base_url": "https://api.siliconflow.cn/v1",
|
||||||
|
"api_key": "sf-***",
|
||||||
|
"model_name": "paraformer-v2",
|
||||||
|
"hotwords": ["小助手", "帮我"],
|
||||||
|
"enable_punctuation": true,
|
||||||
|
"enable_normalization": true,
|
||||||
|
"enabled": true,
|
||||||
|
"created_at": "2024-01-15T10:30:00Z"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. 获取单个 ASR 模型详情
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /api/v1/asr/{id}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Path Parameters:**
|
||||||
|
|
||||||
|
| 参数 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| id | string | 模型ID |
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "abc12345",
|
||||||
|
"user_id": 1,
|
||||||
|
"name": "Whisper 多语种识别",
|
||||||
|
"vendor": "OpenAI",
|
||||||
|
"language": "Multi-lingual",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"api_key": "sk-***",
|
||||||
|
"model_name": "whisper-1",
|
||||||
|
"hotwords": [],
|
||||||
|
"enable_punctuation": true,
|
||||||
|
"enable_normalization": true,
|
||||||
|
"enabled": true,
|
||||||
|
"created_at": "2024-01-15T10:30:00Z"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. 创建 ASR 模型
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /api/v1/asr
|
||||||
|
```
|
||||||
|
|
||||||
|
**Request Body:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "SenseVoice 中文识别",
|
||||||
|
"vendor": "SiliconFlow",
|
||||||
|
"language": "zh",
|
||||||
|
"base_url": "https://api.siliconflow.cn/v1",
|
||||||
|
"api_key": "sk-your-api-key",
|
||||||
|
"model_name": "paraformer-v2",
|
||||||
|
"hotwords": ["小助手", "帮我"],
|
||||||
|
"enable_punctuation": true,
|
||||||
|
"enable_normalization": true,
|
||||||
|
"enabled": true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Fields 说明:**
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| name | string | 是 | 模型显示名称 |
|
||||||
|
| vendor | string | 是 | 供应商: "OpenAI" / "SiliconFlow" / "Paraformer" |
|
||||||
|
| language | string | 是 | 语言: "zh" / "en" / "Multi-lingual" |
|
||||||
|
| base_url | string | 是 | API Base URL |
|
||||||
|
| api_key | string | 是 | API Key |
|
||||||
|
| model_name | string | 否 | 模型名称 |
|
||||||
|
| hotwords | string[] | 否 | 热词列表,提升识别准确率 |
|
||||||
|
| enable_punctuation | boolean | 否 | 是否输出标点,默认 true |
|
||||||
|
| enable_normalization | boolean | 否 | 是否文本规范化,默认 true |
|
||||||
|
| enabled | boolean | 否 | 是否启用,默认 true |
|
||||||
|
| id | string | 否 | 指定模型ID,默认自动生成 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4. 更新 ASR 模型
|
||||||
|
|
||||||
|
```http
|
||||||
|
PUT /api/v1/asr/{id}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Request Body:** (部分更新)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "Whisper-1 优化版",
|
||||||
|
"language": "zh",
|
||||||
|
"enable_punctuation": true,
|
||||||
|
"hotwords": ["新词1", "新词2"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 5. 删除 ASR 模型
|
||||||
|
|
||||||
|
```http
|
||||||
|
DELETE /api/v1/asr/{id}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"message": "Deleted successfully"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 6. 测试 ASR 模型
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /api/v1/asr/{id}/test
|
||||||
|
```
|
||||||
|
|
||||||
|
**Request Body:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"audio_url": "https://example.com/test-audio.wav"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
或使用 Base64 编码的音频数据:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"audio_data": "UklGRi..."
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response (成功):**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"transcript": "您好,请问有什么可以帮助您?",
|
||||||
|
"language": "zh",
|
||||||
|
"confidence": 0.95,
|
||||||
|
"latency_ms": 500
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response (失败):**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": false,
|
||||||
|
"error": "HTTP Error: 401 - Unauthorized"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 7. 转写音频
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /api/v1/asr/{id}/transcribe
|
||||||
|
```
|
||||||
|
|
||||||
|
**Query Parameters:**
|
||||||
|
|
||||||
|
| 参数 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| audio_url | string | 否* | 音频文件URL |
|
||||||
|
| audio_data | string | 否* | Base64编码的音频数据 |
|
||||||
|
| hotwords | string[] | 否 | 热词列表 |
|
||||||
|
|
||||||
|
*二选一,至少提供一个
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"transcript": "您好,请问有什么可以帮助您?",
|
||||||
|
"language": "zh",
|
||||||
|
"confidence": 0.95
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Schema 定义
|
||||||
|
|
||||||
|
```python
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional, List
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
class ASRLanguage(str, Enum):
|
||||||
|
ZH = "zh"
|
||||||
|
EN = "en"
|
||||||
|
MULTILINGUAL = "Multi-lingual"
|
||||||
|
|
||||||
|
class ASRModelBase(BaseModel):
|
||||||
|
name: str
|
||||||
|
vendor: str
|
||||||
|
language: str # "zh" | "en" | "Multi-lingual"
|
||||||
|
base_url: str
|
||||||
|
api_key: str
|
||||||
|
model_name: Optional[str] = None
|
||||||
|
hotwords: List[str] = []
|
||||||
|
enable_punctuation: bool = True
|
||||||
|
enable_normalization: bool = True
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
class ASRModelCreate(ASRModelBase):
|
||||||
|
id: Optional[str] = None
|
||||||
|
|
||||||
|
class ASRModelUpdate(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
language: Optional[str] = None
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
model_name: Optional[str] = None
|
||||||
|
hotwords: Optional[List[str]] = None
|
||||||
|
enable_punctuation: Optional[bool] = None
|
||||||
|
enable_normalization: Optional[bool] = None
|
||||||
|
enabled: Optional[bool] = None
|
||||||
|
|
||||||
|
class ASRModelOut(ASRModelBase):
|
||||||
|
id: str
|
||||||
|
user_id: int
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
class ASRTestRequest(BaseModel):
|
||||||
|
audio_url: Optional[str] = None
|
||||||
|
audio_data: Optional[str] = None # base64 encoded
|
||||||
|
|
||||||
|
class ASRTestResponse(BaseModel):
|
||||||
|
success: bool
|
||||||
|
transcript: Optional[str] = None
|
||||||
|
language: Optional[str] = None
|
||||||
|
confidence: Optional[float] = None
|
||||||
|
latency_ms: Optional[int] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 供应商配置示例
|
||||||
|
|
||||||
|
### OpenAI Whisper
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"vendor": "OpenAI",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"api_key": "sk-xxx",
|
||||||
|
"model_name": "whisper-1",
|
||||||
|
"language": "Multi-lingual",
|
||||||
|
"enable_punctuation": true,
|
||||||
|
"enable_normalization": true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### SiliconFlow Paraformer
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"vendor": "SiliconFlow",
|
||||||
|
"base_url": "https://api.siliconflow.cn/v1",
|
||||||
|
"api_key": "sf-xxx",
|
||||||
|
"model_name": "paraformer-v2",
|
||||||
|
"language": "zh",
|
||||||
|
"hotwords": ["产品名称", "公司名"],
|
||||||
|
"enable_punctuation": true,
|
||||||
|
"enable_normalization": true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 单元测试
|
||||||
|
|
||||||
|
项目包含完整的单元测试,位于 `api/tests/test_asr.py`。
|
||||||
|
|
||||||
|
### 测试用例概览
|
||||||
|
|
||||||
|
| 测试方法 | 说明 |
|
||||||
|
|----------|------|
|
||||||
|
| test_get_asr_models_empty | 空数据库获取测试 |
|
||||||
|
| test_create_asr_model | 创建模型测试 |
|
||||||
|
| test_create_asr_model_minimal | 最小数据创建测试 |
|
||||||
|
| test_get_asr_model_by_id | 获取单个模型测试 |
|
||||||
|
| test_get_asr_model_not_found | 获取不存在模型测试 |
|
||||||
|
| test_update_asr_model | 更新模型测试 |
|
||||||
|
| test_delete_asr_model | 删除模型测试 |
|
||||||
|
| test_list_asr_models_with_pagination | 分页测试 |
|
||||||
|
| test_filter_asr_models_by_language | 按语言过滤测试 |
|
||||||
|
| test_filter_asr_models_by_enabled | 按启用状态过滤测试 |
|
||||||
|
| test_create_asr_model_with_hotwords | 热词配置测试 |
|
||||||
|
| test_test_asr_model_siliconflow | SiliconFlow 供应商测试 |
|
||||||
|
| test_test_asr_model_openai | OpenAI 供应商测试 |
|
||||||
|
| test_different_asr_languages | 多语言测试 |
|
||||||
|
| test_different_asr_vendors | 多供应商测试 |
|
||||||
|
|
||||||
|
### 运行测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 运行 ASR 相关测试
|
||||||
|
pytest api/tests/test_asr.py -v
|
||||||
|
|
||||||
|
# 运行所有测试
|
||||||
|
pytest api/tests/ -v
|
||||||
|
```
|
||||||
@@ -7,9 +7,9 @@
|
|||||||
| 模块 | 文件 | 说明 |
|
| 模块 | 文件 | 说明 |
|
||||||
|------|------|------|
|
|------|------|------|
|
||||||
| 小助手 | [assistant.md](./assistant.md) | AI 助手管理 |
|
| 小助手 | [assistant.md](./assistant.md) | AI 助手管理 |
|
||||||
| 模型接入 | [model-access.md](./model-access.md) | LLM/ASR/TTS 模型配置 |
|
| LLM 模型 | [llm.md](./llm.md) | LLM 模型配置与管理 |
|
||||||
| 语音识别 | [speech-recognition.md](./speech-recognition.md) | ASR 模型配置 |
|
| ASR 模型 | [asr.md](./asr.md) | 语音识别模型配置 |
|
||||||
| 声音资源 | [voice-resources.md](./voice-resources.md) | TTS 声音库管理 |
|
| 工具与测试 | [tools.md](./tools.md) | 工具列表与自动测试 |
|
||||||
| 历史记录 | [history-records.md](./history-records.md) | 通话记录和转写 |
|
| 历史记录 | [history-records.md](./history-records.md) | 通话记录和转写 |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
401
api/docs/llm.md
Normal file
401
api/docs/llm.md
Normal file
@@ -0,0 +1,401 @@
|
|||||||
|
# LLM 模型 (LLM Model) API
|
||||||
|
|
||||||
|
LLM 模型 API 用于管理大语言模型的配置和调用。
|
||||||
|
|
||||||
|
## 基础信息
|
||||||
|
|
||||||
|
| 项目 | 值 |
|
||||||
|
|------|-----|
|
||||||
|
| Base URL | `/api/v1/llm` |
|
||||||
|
| 认证方式 | Bearer Token (预留) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 数据模型
|
||||||
|
|
||||||
|
### LLMModel
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
interface LLMModel {
|
||||||
|
id: string; // 模型唯一标识 (8位UUID)
|
||||||
|
user_id: number; // 所属用户ID
|
||||||
|
name: string; // 模型显示名称
|
||||||
|
vendor: string; // 供应商: "OpenAI" | "SiliconFlow" | "Dify" | "FastGPT" | 等
|
||||||
|
type: string; // 类型: "text" | "embedding" | "rerank"
|
||||||
|
base_url: string; // API Base URL
|
||||||
|
api_key: string; // API Key
|
||||||
|
model_name?: string; // 实际模型名称,如 "gpt-4o"
|
||||||
|
temperature?: number; // 温度参数 (0-2)
|
||||||
|
context_length?: int; // 上下文长度
|
||||||
|
enabled: boolean; // 是否启用
|
||||||
|
created_at: string;
|
||||||
|
updated_at: string;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API 端点
|
||||||
|
|
||||||
|
### 1. 获取 LLM 模型列表
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /api/v1/llm
|
||||||
|
```
|
||||||
|
|
||||||
|
**Query Parameters:**
|
||||||
|
|
||||||
|
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||||
|
|------|------|------|--------|------|
|
||||||
|
| model_type | string | 否 | - | 过滤类型: "text" \| "embedding" \| "rerank" |
|
||||||
|
| enabled | boolean | 否 | - | 过滤启用状态 |
|
||||||
|
| page | int | 否 | 1 | 页码 |
|
||||||
|
| limit | int | 否 | 50 | 每页数量 |
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"total": 5,
|
||||||
|
"page": 1,
|
||||||
|
"limit": 50,
|
||||||
|
"list": [
|
||||||
|
{
|
||||||
|
"id": "abc12345",
|
||||||
|
"user_id": 1,
|
||||||
|
"name": "GPT-4o",
|
||||||
|
"vendor": "OpenAI",
|
||||||
|
"type": "text",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"api_key": "sk-***",
|
||||||
|
"model_name": "gpt-4o",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"context_length": 128000,
|
||||||
|
"enabled": true,
|
||||||
|
"created_at": "2024-01-15T10:30:00Z",
|
||||||
|
"updated_at": "2024-01-15T10:30:00Z"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "def67890",
|
||||||
|
"user_id": 1,
|
||||||
|
"name": "Embedding-3-Small",
|
||||||
|
"vendor": "OpenAI",
|
||||||
|
"type": "embedding",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"api_key": "sk-***",
|
||||||
|
"model_name": "text-embedding-3-small",
|
||||||
|
"enabled": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. 获取单个 LLM 模型详情
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /api/v1/llm/{id}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Path Parameters:**
|
||||||
|
|
||||||
|
| 参数 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| id | string | 模型ID |
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "abc12345",
|
||||||
|
"user_id": 1,
|
||||||
|
"name": "GPT-4o",
|
||||||
|
"vendor": "OpenAI",
|
||||||
|
"type": "text",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"api_key": "sk-***",
|
||||||
|
"model_name": "gpt-4o",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"context_length": 128000,
|
||||||
|
"enabled": true,
|
||||||
|
"created_at": "2024-01-15T10:30:00Z",
|
||||||
|
"updated_at": "2024-01-15T10:30:00Z"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. 创建 LLM 模型
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /api/v1/llm
|
||||||
|
```
|
||||||
|
|
||||||
|
**Request Body:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "GPT-4o",
|
||||||
|
"vendor": "OpenAI",
|
||||||
|
"type": "text",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"api_key": "sk-your-api-key",
|
||||||
|
"model_name": "gpt-4o",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"context_length": 128000,
|
||||||
|
"enabled": true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Fields 说明:**
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| name | string | 是 | 模型显示名称 |
|
||||||
|
| vendor | string | 是 | 供应商名称 |
|
||||||
|
| type | string | 是 | 模型类型: "text" / "embedding" / "rerank" |
|
||||||
|
| base_url | string | 是 | API Base URL |
|
||||||
|
| api_key | string | 是 | API Key |
|
||||||
|
| model_name | string | 否 | 实际模型名称 |
|
||||||
|
| temperature | number | 否 | 温度参数,默认 0.7 |
|
||||||
|
| context_length | int | 否 | 上下文长度 |
|
||||||
|
| enabled | boolean | 否 | 是否启用,默认 true |
|
||||||
|
| id | string | 否 | 指定模型ID,默认自动生成 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4. 更新 LLM 模型
|
||||||
|
|
||||||
|
```http
|
||||||
|
PUT /api/v1/llm/{id}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Request Body:** (部分更新)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "GPT-4o-Updated",
|
||||||
|
"temperature": 0.8,
|
||||||
|
"enabled": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 5. 删除 LLM 模型
|
||||||
|
|
||||||
|
```http
|
||||||
|
DELETE /api/v1/llm/{id}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"message": "Deleted successfully"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 6. 测试 LLM 模型连接
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /api/v1/llm/{id}/test
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"latency_ms": 150,
|
||||||
|
"message": "Connection successful"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**错误响应:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": false,
|
||||||
|
"latency_ms": 200,
|
||||||
|
"message": "HTTP Error: 401 - Unauthorized"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 7. 与 LLM 模型对话
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /api/v1/llm/{id}/chat
|
||||||
|
```
|
||||||
|
|
||||||
|
**Query Parameters:**
|
||||||
|
|
||||||
|
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||||
|
|------|------|------|--------|------|
|
||||||
|
| message | string | 是 | - | 用户消息 |
|
||||||
|
| system_prompt | string | 否 | - | 系统提示词 |
|
||||||
|
| max_tokens | int | 否 | 1000 | 最大生成token数 |
|
||||||
|
| temperature | number | 否 | 模型配置 | 温度参数 |
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"reply": "您好!有什么可以帮助您的?",
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 20,
|
||||||
|
"completion_tokens": 15,
|
||||||
|
"total_tokens": 35
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Schema 定义
|
||||||
|
|
||||||
|
```python
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
class LLMModelType(str, Enum):
|
||||||
|
TEXT = "text"
|
||||||
|
EMBEDDING = "embedding"
|
||||||
|
RERANK = "rerank"
|
||||||
|
|
||||||
|
class LLMModelBase(BaseModel):
|
||||||
|
name: str
|
||||||
|
vendor: str
|
||||||
|
type: LLMModelType
|
||||||
|
base_url: str
|
||||||
|
api_key: str
|
||||||
|
model_name: Optional[str] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
context_length: Optional[int] = None
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
class LLMModelCreate(LLMModelBase):
|
||||||
|
id: Optional[str] = None
|
||||||
|
|
||||||
|
class LLMModelUpdate(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
vendor: Optional[str] = None
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
model_name: Optional[str] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
context_length: Optional[int] = None
|
||||||
|
enabled: Optional[bool] = None
|
||||||
|
|
||||||
|
class LLMModelOut(LLMModelBase):
|
||||||
|
id: str
|
||||||
|
user_id: int
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
class LLMModelTestResponse(BaseModel):
|
||||||
|
success: bool
|
||||||
|
latency_ms: int
|
||||||
|
message: str
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 供应商配置示例
|
||||||
|
|
||||||
|
### OpenAI
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"vendor": "OpenAI",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"api_key": "sk-xxx",
|
||||||
|
"model_name": "gpt-4o",
|
||||||
|
"type": "text",
|
||||||
|
"temperature": 0.7
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### SiliconFlow
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"vendor": "SiliconFlow",
|
||||||
|
"base_url": "https://api.siliconflow.com/v1",
|
||||||
|
"api_key": "sf-xxx",
|
||||||
|
"model_name": "deepseek-v3",
|
||||||
|
"type": "text",
|
||||||
|
"temperature": 0.7
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Dify
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"vendor": "Dify",
|
||||||
|
"base_url": "https://your-dify.domain.com/v1",
|
||||||
|
"api_key": "app-xxx",
|
||||||
|
"model_name": "gpt-4",
|
||||||
|
"type": "text"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Embedding 模型
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"vendor": "OpenAI",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"api_key": "sk-xxx",
|
||||||
|
"model_name": "text-embedding-3-small",
|
||||||
|
"type": "embedding"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 单元测试
|
||||||
|
|
||||||
|
项目包含完整的单元测试,位于 `api/tests/test_llm.py`。
|
||||||
|
|
||||||
|
### 测试用例概览
|
||||||
|
|
||||||
|
| 测试方法 | 说明 |
|
||||||
|
|----------|------|
|
||||||
|
| test_get_llm_models_empty | 空数据库获取测试 |
|
||||||
|
| test_create_llm_model | 创建模型测试 |
|
||||||
|
| test_create_llm_model_minimal | 最小数据创建测试 |
|
||||||
|
| test_get_llm_model_by_id | 获取单个模型测试 |
|
||||||
|
| test_get_llm_model_not_found | 获取不存在模型测试 |
|
||||||
|
| test_update_llm_model | 更新模型测试 |
|
||||||
|
| test_delete_llm_model | 删除模型测试 |
|
||||||
|
| test_list_llm_models_with_pagination | 分页测试 |
|
||||||
|
| test_filter_llm_models_by_type | 按类型过滤测试 |
|
||||||
|
| test_filter_llm_models_by_enabled | 按启用状态过滤测试 |
|
||||||
|
| test_create_llm_model_with_all_fields | 全字段创建测试 |
|
||||||
|
| test_test_llm_model_success | 测试连接成功测试 |
|
||||||
|
| test_test_llm_model_failure | 测试连接失败测试 |
|
||||||
|
| test_different_llm_vendors | 多供应商测试 |
|
||||||
|
| test_embedding_llm_model | Embedding 模型测试 |
|
||||||
|
|
||||||
|
### 运行测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 运行 LLM 相关测试
|
||||||
|
pytest api/tests/test_llm.py -v
|
||||||
|
|
||||||
|
# 运行所有测试
|
||||||
|
pytest api/tests/ -v
|
||||||
|
```
|
||||||
445
api/docs/tools.md
Normal file
445
api/docs/tools.md
Normal file
@@ -0,0 +1,445 @@
|
|||||||
|
# 工具与自动测试 (Tools & Autotest) API
|
||||||
|
|
||||||
|
工具与自动测试 API 用于管理可用工具列表和自动测试功能。
|
||||||
|
|
||||||
|
## 基础信息
|
||||||
|
|
||||||
|
| 项目 | 值 |
|
||||||
|
|------|-----|
|
||||||
|
| Base URL | `/api/v1/tools` |
|
||||||
|
| 认证方式 | Bearer Token (预留) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 可用工具 (Tool Registry)
|
||||||
|
|
||||||
|
系统内置以下工具:
|
||||||
|
|
||||||
|
| 工具ID | 名称 | 说明 |
|
||||||
|
|--------|------|------|
|
||||||
|
| search | 网络搜索 | 搜索互联网获取最新信息 |
|
||||||
|
| calculator | 计算器 | 执行数学计算 |
|
||||||
|
| weather | 天气查询 | 查询指定城市的天气 |
|
||||||
|
| translate | 翻译 | 翻译文本到指定语言 |
|
||||||
|
| knowledge | 知识库查询 | 从知识库中检索相关信息 |
|
||||||
|
| code_interpreter | 代码执行 | 安全地执行Python代码 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API 端点
|
||||||
|
|
||||||
|
### 1. 获取可用工具列表
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /api/v1/tools/list
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"tools": {
|
||||||
|
"search": {
|
||||||
|
"name": "网络搜索",
|
||||||
|
"description": "搜索互联网获取最新信息",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string", "description": "搜索关键词"}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"calculator": {
|
||||||
|
"name": "计算器",
|
||||||
|
"description": "执行数学计算",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}
|
||||||
|
},
|
||||||
|
"required": ["expression"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"weather": {
|
||||||
|
"name": "天气查询",
|
||||||
|
"description": "查询指定城市的天气",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {"type": "string", "description": "城市名称"}
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"translate": {
|
||||||
|
"name": "翻译",
|
||||||
|
"description": "翻译文本到指定语言",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {"type": "string", "description": "要翻译的文本"},
|
||||||
|
"target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"}
|
||||||
|
},
|
||||||
|
"required": ["text", "target_lang"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"knowledge": {
|
||||||
|
"name": "知识库查询",
|
||||||
|
"description": "从知识库中检索相关信息",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string", "description": "查询内容"},
|
||||||
|
"kb_id": {"type": "string", "description": "知识库ID"}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"code_interpreter": {
|
||||||
|
"name": "代码执行",
|
||||||
|
"description": "安全地执行Python代码",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"code": {"type": "string", "description": "要执行的Python代码"}
|
||||||
|
},
|
||||||
|
"required": ["code"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. 获取工具详情
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /api/v1/tools/list/{tool_id}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Path Parameters:**
|
||||||
|
|
||||||
|
| 参数 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| tool_id | string | 工具ID |
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "计算器",
|
||||||
|
"description": "执行数学计算",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"expression": {"type": "string", "description": "数学表达式,如: 2 + 3 * 4"}
|
||||||
|
},
|
||||||
|
"required": ["expression"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**错误响应 (工具不存在):**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"detail": "Tool not found"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. 健康检查
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /api/v1/tools/health
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "healthy",
|
||||||
|
"timestamp": 1705315200.123,
|
||||||
|
"tools": ["search", "calculator", "weather", "translate", "knowledge", "code_interpreter"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 自动测试 (Autotest)
|
||||||
|
|
||||||
|
### 4. 运行完整自动测试
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /api/v1/tools/autotest
|
||||||
|
```
|
||||||
|
|
||||||
|
**Query Parameters:**
|
||||||
|
|
||||||
|
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||||
|
|------|------|------|--------|------|
|
||||||
|
| llm_model_id | string | 否 | - | LLM 模型ID |
|
||||||
|
| asr_model_id | string | 否 | - | ASR 模型ID |
|
||||||
|
| test_llm | boolean | 否 | true | 是否测试LLM |
|
||||||
|
| test_asr | boolean | 否 | true | 是否测试ASR |
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "abc12345",
|
||||||
|
"started_at": 1705315200.0,
|
||||||
|
"duration_ms": 2500,
|
||||||
|
"tests": [
|
||||||
|
{
|
||||||
|
"name": "Model Existence",
|
||||||
|
"passed": true,
|
||||||
|
"message": "Found model: GPT-4o",
|
||||||
|
"duration_ms": 15
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "API Connection",
|
||||||
|
"passed": true,
|
||||||
|
"message": "Latency: 150ms",
|
||||||
|
"duration_ms": 150
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Temperature Setting",
|
||||||
|
"passed": true,
|
||||||
|
"message": "temperature=0.7"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Streaming Support",
|
||||||
|
"passed": true,
|
||||||
|
"message": "Received 15 chunks",
|
||||||
|
"duration_ms": 800
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"summary": {
|
||||||
|
"passed": 4,
|
||||||
|
"failed": 0,
|
||||||
|
"total": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 5. 测试单个 LLM 模型
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /api/v1/tools/autotest/llm/{model_id}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Path Parameters:**
|
||||||
|
|
||||||
|
| 参数 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| model_id | string | LLM 模型ID |
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "llm_test_001",
|
||||||
|
"started_at": 1705315200.0,
|
||||||
|
"duration_ms": 1200,
|
||||||
|
"tests": [
|
||||||
|
{
|
||||||
|
"name": "Model Existence",
|
||||||
|
"passed": true,
|
||||||
|
"message": "Found model: GPT-4o",
|
||||||
|
"duration_ms": 10
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "API Connection",
|
||||||
|
"passed": true,
|
||||||
|
"message": "Latency: 180ms",
|
||||||
|
"duration_ms": 180
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Temperature Setting",
|
||||||
|
"passed": true,
|
||||||
|
"message": "temperature=0.7"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Streaming Support",
|
||||||
|
"passed": true,
|
||||||
|
"message": "Received 12 chunks",
|
||||||
|
"duration_ms": 650
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"summary": {
|
||||||
|
"passed": 4,
|
||||||
|
"failed": 0,
|
||||||
|
"total": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 6. 测试单个 ASR 模型
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /api/v1/tools/autotest/asr/{model_id}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Path Parameters:**
|
||||||
|
|
||||||
|
| 参数 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| model_id | string | ASR 模型ID |
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "asr_test_001",
|
||||||
|
"started_at": 1705315200.0,
|
||||||
|
"duration_ms": 800,
|
||||||
|
"tests": [
|
||||||
|
{
|
||||||
|
"name": "Model Existence",
|
||||||
|
"passed": true,
|
||||||
|
"message": "Found model: Whisper-1",
|
||||||
|
"duration_ms": 8
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Hotwords Config",
|
||||||
|
"passed": true,
|
||||||
|
"message": "Hotwords: 3 words"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "API Availability",
|
||||||
|
"passed": true,
|
||||||
|
"message": "Status: 200",
|
||||||
|
"duration_ms": 250
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Language Config",
|
||||||
|
"passed": true,
|
||||||
|
"message": "Language: zh"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"summary": {
|
||||||
|
"passed": 4,
|
||||||
|
"failed": 0,
|
||||||
|
"total": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 7. 发送测试消息
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /api/v1/tools/test-message
|
||||||
|
```
|
||||||
|
|
||||||
|
**Query Parameters:**
|
||||||
|
|
||||||
|
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||||
|
|------|------|------|--------|------|
|
||||||
|
| llm_model_id | string | 是 | - | LLM 模型ID |
|
||||||
|
| message | string | 否 | "Hello, this is a test message." | 测试消息 |
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"reply": "Hello! This is a test reply from GPT-4o.",
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 15,
|
||||||
|
"completion_tokens": 12,
|
||||||
|
"total_tokens": 27
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**错误响应 (模型不存在):**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"detail": "LLM Model not found"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 测试结果结构
|
||||||
|
|
||||||
|
### AutotestResult
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
interface AutotestResult {
|
||||||
|
id: string; // 测试ID
|
||||||
|
started_at: number; // 开始时间戳
|
||||||
|
duration_ms: number; // 总耗时(毫秒)
|
||||||
|
tests: TestCase[]; // 测试用例列表
|
||||||
|
summary: TestSummary; // 测试摘要
|
||||||
|
}
|
||||||
|
|
||||||
|
interface TestCase {
|
||||||
|
name: string; // 测试名称
|
||||||
|
passed: boolean; // 是否通过
|
||||||
|
message: string; // 测试消息
|
||||||
|
duration_ms: number; // 耗时(毫秒)
|
||||||
|
}
|
||||||
|
|
||||||
|
interface TestSummary {
|
||||||
|
passed: number; // 通过数量
|
||||||
|
failed: number; // 失败数量
|
||||||
|
total: number; // 总数量
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 测试项目说明
|
||||||
|
|
||||||
|
### LLM 模型测试项目
|
||||||
|
|
||||||
|
| 测试名称 | 说明 |
|
||||||
|
|----------|------|
|
||||||
|
| Model Existence | 检查模型是否存在于数据库 |
|
||||||
|
| API Connection | 测试 API 连接并测量延迟 |
|
||||||
|
| Temperature Setting | 检查温度配置 |
|
||||||
|
| Streaming Support | 测试流式响应支持 |
|
||||||
|
|
||||||
|
### ASR 模型测试项目
|
||||||
|
|
||||||
|
| 测试名称 | 说明 |
|
||||||
|
|----------|------|
|
||||||
|
| Model Existence | 检查模型是否存在于数据库 |
|
||||||
|
| Hotwords Config | 检查热词配置 |
|
||||||
|
| API Availability | 测试 API 可用性 |
|
||||||
|
| Language Config | 检查语言配置 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 单元测试
|
||||||
|
|
||||||
|
项目包含完整的单元测试,位于 `api/tests/test_tools.py`。
|
||||||
|
|
||||||
|
### 测试用例概览
|
||||||
|
|
||||||
|
| 测试类 | 说明 |
|
||||||
|
|--------|------|
|
||||||
|
| TestToolsAPI | 工具列表、健康检查等基础功能测试 |
|
||||||
|
| TestAutotestAPI | 自动测试功能完整测试 |
|
||||||
|
|
||||||
|
### 运行测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 运行工具相关测试
|
||||||
|
pytest api/tests/test_tools.py -v
|
||||||
|
|
||||||
|
# 运行所有测试
|
||||||
|
pytest api/tests/ -v
|
||||||
|
```
|
||||||
374
api/init_db.py
374
api/init_db.py
@@ -6,38 +6,381 @@ import sys
|
|||||||
# 添加路径
|
# 添加路径
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
from app.db import Base, engine
|
from app.db import Base, engine, DATABASE_URL
|
||||||
from app.models import Voice
|
from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel
|
||||||
|
|
||||||
|
|
||||||
def init_db():
|
def init_db():
|
||||||
"""创建所有表"""
|
"""创建所有表"""
|
||||||
|
# 确保 data 目录存在
|
||||||
|
data_dir = os.path.dirname(DATABASE_URL.replace("sqlite:///", ""))
|
||||||
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
print("📦 创建数据库表...")
|
print("📦 创建数据库表...")
|
||||||
Base.metadata.drop_all(bind=engine) # 删除旧表
|
Base.metadata.drop_all(bind=engine) # 删除旧表
|
||||||
Base.metadata.create_all(bind=engine)
|
Base.metadata.create_all(bind=engine)
|
||||||
print("✅ 数据库表创建完成")
|
print("✅ 数据库表创建完成")
|
||||||
|
|
||||||
|
|
||||||
def init_default_voices():
|
def init_default_data():
|
||||||
"""初始化默认声音"""
|
from sqlalchemy.orm import Session
|
||||||
from app.db import SessionLocal
|
from app.db import SessionLocal
|
||||||
|
from app.models import Voice
|
||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
|
# 检查是否已有数据
|
||||||
if db.query(Voice).count() == 0:
|
if db.query(Voice).count() == 0:
|
||||||
|
# SiliconFlow CosyVoice 2.0 预设声音 (8个)
|
||||||
|
# 参考: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
|
||||||
voices = [
|
voices = [
|
||||||
Voice(id="v1", name="Xiaoyun", vendor="Ali", gender="Female", language="zh", description="Gentle and professional."),
|
# 男声 (Male Voices)
|
||||||
Voice(id="v2", name="Kevin", vendor="Volcano", gender="Male", language="en", description="Deep and authoritative."),
|
Voice(id="alex", name="Alex", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
Voice(id="v3", name="Abby", vendor="Minimax", gender="Female", language="en", description="Cheerful and lively."),
|
description="Steady male voice.", is_system=True),
|
||||||
Voice(id="v4", name="Guang", vendor="Ali", gender="Male", language="zh", description="Standard newscast style."),
|
Voice(id="benjamin", name="Benjamin", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
Voice(id="v5", name="Doubao", vendor="Volcano", gender="Female", language="zh", description="Cute and young."),
|
description="Deep male voice.", is_system=True),
|
||||||
|
Voice(id="charles", name="Charles", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
|
description="Magnetic male voice.", is_system=True),
|
||||||
|
Voice(id="david", name="David", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
|
description="Cheerful male voice.", is_system=True),
|
||||||
|
# 女声 (Female Voices)
|
||||||
|
Voice(id="anna", name="Anna", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Steady female voice.", is_system=True),
|
||||||
|
Voice(id="bella", name="Bella", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Passionate female voice.", is_system=True),
|
||||||
|
Voice(id="claire", name="Claire", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Gentle female voice.", is_system=True),
|
||||||
|
Voice(id="diana", name="Diana", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Cheerful female voice.", is_system=True),
|
||||||
|
# 中文方言 (Chinese Dialects) - 可选扩展
|
||||||
|
Voice(id="amador", name="Amador", vendor="SiliconFlow", gender="Male", language="zh",
|
||||||
|
description="Male voice with Spanish accent."),
|
||||||
|
Voice(id="aelora", name="Aelora", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Elegant female voice."),
|
||||||
|
Voice(id="aelwin", name="Aelwin", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
|
description="Deep male voice."),
|
||||||
|
Voice(id="blooming", name="Blooming", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Fresh and clear female voice."),
|
||||||
|
Voice(id="elysia", name="Elysia", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Smooth and silky female voice."),
|
||||||
|
Voice(id="leo", name="Leo", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
|
description="Young male voice."),
|
||||||
|
Voice(id="lin", name="Lin", vendor="SiliconFlow", gender="Female", language="zh",
|
||||||
|
description="Standard Chinese female voice."),
|
||||||
|
Voice(id="rose", name="Rose", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Soft and gentle female voice."),
|
||||||
|
Voice(id="shao", name="Shao", vendor="SiliconFlow", gender="Male", language="zh",
|
||||||
|
description="Deep Chinese male voice."),
|
||||||
|
Voice(id="sky", name="Sky", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
|
description="Clear and bright male voice."),
|
||||||
|
Voice(id="ael西山", name="Ael西山", vendor="SiliconFlow", gender="Female", language="zh",
|
||||||
|
description="Female voice with Chinese dialect."),
|
||||||
]
|
]
|
||||||
for v in voices:
|
for v in voices:
|
||||||
db.add(v)
|
db.add(v)
|
||||||
db.commit()
|
db.commit()
|
||||||
print("✅ 默认声音数据已初始化")
|
print("✅ 默认声音数据已初始化 (SiliconFlow CosyVoice 2.0)")
|
||||||
else:
|
finally:
|
||||||
print("ℹ️ 声音数据已存在,跳过初始化")
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def init_default_assistants():
|
||||||
|
"""初始化默认助手"""
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from app.db import SessionLocal
|
||||||
|
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
if db.query(Assistant).count() == 0:
|
||||||
|
assistants = [
|
||||||
|
Assistant(
|
||||||
|
id="default",
|
||||||
|
user_id=1,
|
||||||
|
name="AI 助手",
|
||||||
|
call_count=0,
|
||||||
|
opener="你好!我是AI助手,有什么可以帮你的吗?",
|
||||||
|
prompt="你是一个友好的AI助手,请用简洁清晰的语言回答用户的问题。",
|
||||||
|
language="zh",
|
||||||
|
voice="anna",
|
||||||
|
speed=1.0,
|
||||||
|
hotwords=[],
|
||||||
|
tools=["search", "calculator"],
|
||||||
|
interruption_sensitivity=500,
|
||||||
|
config_mode="platform",
|
||||||
|
llm_model_id="deepseek-chat",
|
||||||
|
asr_model_id="paraformer-v2",
|
||||||
|
),
|
||||||
|
Assistant(
|
||||||
|
id="customer_service",
|
||||||
|
user_id=1,
|
||||||
|
name="客服助手",
|
||||||
|
call_count=0,
|
||||||
|
opener="您好,欢迎致电客服中心,请问有什么可以帮您?",
|
||||||
|
prompt="你是一个专业的客服人员,耐心解答客户问题,提供优质的服务体验。",
|
||||||
|
language="zh",
|
||||||
|
voice="bella",
|
||||||
|
speed=1.0,
|
||||||
|
hotwords=["客服", "投诉", "咨询"],
|
||||||
|
tools=["search"],
|
||||||
|
interruption_sensitivity=600,
|
||||||
|
config_mode="platform",
|
||||||
|
),
|
||||||
|
Assistant(
|
||||||
|
id="english_tutor",
|
||||||
|
user_id=1,
|
||||||
|
name="英语导师",
|
||||||
|
call_count=0,
|
||||||
|
opener="Hello! I'm your English learning companion. How can I help you today?",
|
||||||
|
prompt="You are a friendly English tutor. Help users practice English conversation and explain grammar points clearly.",
|
||||||
|
language="en",
|
||||||
|
voice="alex",
|
||||||
|
speed=1.0,
|
||||||
|
hotwords=["grammar", "vocabulary", "practice"],
|
||||||
|
tools=[],
|
||||||
|
interruption_sensitivity=400,
|
||||||
|
config_mode="platform",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for a in assistants:
|
||||||
|
db.add(a)
|
||||||
|
db.commit()
|
||||||
|
print("✅ 默认助手数据已初始化")
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def init_default_workflows():
|
||||||
|
"""初始化默认工作流"""
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from app.db import SessionLocal
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
if db.query(Workflow).count() == 0:
|
||||||
|
now = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
workflows = [
|
||||||
|
Workflow(
|
||||||
|
id="simple_conversation",
|
||||||
|
user_id=1,
|
||||||
|
name="简单对话",
|
||||||
|
node_count=2,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
global_prompt="处理简单的对话流程,用户问什么答什么。",
|
||||||
|
nodes=[
|
||||||
|
{"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
|
||||||
|
{"id": "2", "type": "ai_reply", "position": {"x": 300, "y": 100}, "data": {"label": "AI回复"}},
|
||||||
|
],
|
||||||
|
edges=[{"source": "1", "target": "2", "id": "e1-2"}],
|
||||||
|
),
|
||||||
|
Workflow(
|
||||||
|
id="voice_input_flow",
|
||||||
|
user_id=1,
|
||||||
|
name="语音输入流程",
|
||||||
|
node_count=4,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
global_prompt="处理语音输入的完整流程。",
|
||||||
|
nodes=[
|
||||||
|
{"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
|
||||||
|
{"id": "2", "type": "asr", "position": {"x": 250, "y": 100}, "data": {"label": "语音识别"}},
|
||||||
|
{"id": "3", "type": "llm", "position": {"x": 400, "y": 100}, "data": {"label": "LLM处理"}},
|
||||||
|
{"id": "4", "type": "tts", "position": {"x": 550, "y": 100}, "data": {"label": "语音合成"}},
|
||||||
|
],
|
||||||
|
edges=[
|
||||||
|
{"source": "1", "target": "2", "id": "e1-2"},
|
||||||
|
{"source": "2", "target": "3", "id": "e2-3"},
|
||||||
|
{"source": "3", "target": "4", "id": "e3-4"},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for w in workflows:
|
||||||
|
db.add(w)
|
||||||
|
db.commit()
|
||||||
|
print("✅ 默认工作流数据已初始化")
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def init_default_knowledge_bases():
|
||||||
|
"""初始化默认知识库"""
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from app.db import SessionLocal
|
||||||
|
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
if db.query(KnowledgeBase).count() == 0:
|
||||||
|
kb = KnowledgeBase(
|
||||||
|
id="default_kb",
|
||||||
|
user_id=1,
|
||||||
|
name="默认知识库",
|
||||||
|
description="系统默认知识库,用于存储常见问题解答。",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
chunk_size=500,
|
||||||
|
chunk_overlap=50,
|
||||||
|
doc_count=0,
|
||||||
|
chunk_count=0,
|
||||||
|
status="active",
|
||||||
|
)
|
||||||
|
db.add(kb)
|
||||||
|
db.commit()
|
||||||
|
print("✅ 默认知识库已初始化")
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def init_default_llm_models():
|
||||||
|
"""初始化默认LLM模型"""
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from app.db import SessionLocal
|
||||||
|
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
if db.query(LLMModel).count() == 0:
|
||||||
|
llm_models = [
|
||||||
|
LLMModel(
|
||||||
|
id="deepseek-chat",
|
||||||
|
user_id=1,
|
||||||
|
name="DeepSeek Chat",
|
||||||
|
vendor="SiliconFlow",
|
||||||
|
type="text",
|
||||||
|
base_url="https://api.deepseek.com",
|
||||||
|
api_key="YOUR_API_KEY", # 用户需替换
|
||||||
|
model_name="deepseek-chat",
|
||||||
|
temperature=0.7,
|
||||||
|
context_length=4096,
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
LLMModel(
|
||||||
|
id="deepseek-reasoner",
|
||||||
|
user_id=1,
|
||||||
|
name="DeepSeek Reasoner",
|
||||||
|
vendor="SiliconFlow",
|
||||||
|
type="text",
|
||||||
|
base_url="https://api.deepseek.com",
|
||||||
|
api_key="YOUR_API_KEY",
|
||||||
|
model_name="deepseek-reasoner",
|
||||||
|
temperature=0.7,
|
||||||
|
context_length=4096,
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
LLMModel(
|
||||||
|
id="gpt-4o",
|
||||||
|
user_id=1,
|
||||||
|
name="GPT-4o",
|
||||||
|
vendor="OpenAI",
|
||||||
|
type="text",
|
||||||
|
base_url="https://api.openai.com/v1",
|
||||||
|
api_key="YOUR_API_KEY",
|
||||||
|
model_name="gpt-4o",
|
||||||
|
temperature=0.7,
|
||||||
|
context_length=16384,
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
LLMModel(
|
||||||
|
id="glm-4",
|
||||||
|
user_id=1,
|
||||||
|
name="GLM-4",
|
||||||
|
vendor="ZhipuAI",
|
||||||
|
type="text",
|
||||||
|
base_url="https://open.bigmodel.cn/api/paas/v4",
|
||||||
|
api_key="YOUR_API_KEY",
|
||||||
|
model_name="glm-4",
|
||||||
|
temperature=0.7,
|
||||||
|
context_length=8192,
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
LLMModel(
|
||||||
|
id="text-embedding-3-small",
|
||||||
|
user_id=1,
|
||||||
|
name="Embedding 3 Small",
|
||||||
|
vendor="OpenAI",
|
||||||
|
type="embedding",
|
||||||
|
base_url="https://api.openai.com/v1",
|
||||||
|
api_key="YOUR_API_KEY",
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for m in llm_models:
|
||||||
|
db.add(m)
|
||||||
|
db.commit()
|
||||||
|
print("✅ 默认LLM模型已初始化")
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def init_default_asr_models():
|
||||||
|
"""初始化默认ASR模型"""
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from app.db import SessionLocal
|
||||||
|
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
if db.query(ASRModel).count() == 0:
|
||||||
|
asr_models = [
|
||||||
|
ASRModel(
|
||||||
|
id="paraformer-v2",
|
||||||
|
user_id=1,
|
||||||
|
name="Paraformer V2",
|
||||||
|
vendor="SiliconFlow",
|
||||||
|
language="zh",
|
||||||
|
base_url="https://api.siliconflow.cn/v1",
|
||||||
|
api_key="YOUR_API_KEY",
|
||||||
|
model_name="paraformer-v2",
|
||||||
|
hotwords=["人工智能", "机器学习"],
|
||||||
|
enable_punctuation=True,
|
||||||
|
enable_normalization=True,
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
ASRModel(
|
||||||
|
id="paraformer-en",
|
||||||
|
user_id=1,
|
||||||
|
name="Paraformer English",
|
||||||
|
vendor="SiliconFlow",
|
||||||
|
language="en",
|
||||||
|
base_url="https://api.siliconflow.cn/v1",
|
||||||
|
api_key="YOUR_API_KEY",
|
||||||
|
model_name="paraformer-en",
|
||||||
|
hotwords=[],
|
||||||
|
enable_punctuation=True,
|
||||||
|
enable_normalization=True,
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
ASRModel(
|
||||||
|
id="whisper-1",
|
||||||
|
user_id=1,
|
||||||
|
name="Whisper",
|
||||||
|
vendor="OpenAI",
|
||||||
|
language="Multi-lingual",
|
||||||
|
base_url="https://api.openai.com/v1",
|
||||||
|
api_key="YOUR_API_KEY",
|
||||||
|
model_name="whisper-1",
|
||||||
|
hotwords=[],
|
||||||
|
enable_punctuation=True,
|
||||||
|
enable_normalization=True,
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
ASRModel(
|
||||||
|
id="sensevoice",
|
||||||
|
user_id=1,
|
||||||
|
name="SenseVoice",
|
||||||
|
vendor="SiliconFlow",
|
||||||
|
language="Multi-lingual",
|
||||||
|
base_url="https://api.siliconflow.cn/v1",
|
||||||
|
api_key="YOUR_API_KEY",
|
||||||
|
model_name="sensevoice",
|
||||||
|
hotwords=[],
|
||||||
|
enable_punctuation=True,
|
||||||
|
enable_normalization=True,
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for m in asr_models:
|
||||||
|
db.add(m)
|
||||||
|
db.commit()
|
||||||
|
print("✅ 默认ASR模型已初始化")
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -48,5 +391,10 @@ if __name__ == "__main__":
|
|||||||
os.makedirs(data_dir, exist_ok=True)
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
init_db()
|
init_db()
|
||||||
init_default_voices()
|
init_default_data()
|
||||||
|
init_default_assistants()
|
||||||
|
init_default_workflows()
|
||||||
|
init_default_knowledge_bases()
|
||||||
|
init_default_llm_models()
|
||||||
|
init_default_asr_models()
|
||||||
print("🎉 数据库初始化完成!")
|
print("🎉 数据库初始化完成!")
|
||||||
|
|||||||
54
api/main.py
54
api/main.py
@@ -6,6 +6,9 @@ import os
|
|||||||
from app.db import Base, engine
|
from app.db import Base, engine
|
||||||
from app.routers import assistants, history, knowledge
|
from app.routers import assistants, history, knowledge
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
PORT = int(os.getenv("PORT", 8100))
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
@@ -57,17 +60,54 @@ def init_default_data():
|
|||||||
try:
|
try:
|
||||||
# 检查是否已有数据
|
# 检查是否已有数据
|
||||||
if db.query(Voice).count() == 0:
|
if db.query(Voice).count() == 0:
|
||||||
# 插入默认声音
|
# SiliconFlow CosyVoice 2.0 预设声音 (8个)
|
||||||
|
# 参考: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
|
||||||
voices = [
|
voices = [
|
||||||
Voice(id="v1", name="Xiaoyun", vendor="Ali", gender="Female", language="zh", description="Gentle and professional."),
|
# 男声 (Male Voices)
|
||||||
Voice(id="v2", name="Kevin", vendor="Volcano", gender="Male", language="en", description="Deep and authoritative."),
|
Voice(id="alex", name="Alex", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
Voice(id="v3", name="Abby", vendor="Minimax", gender="Female", language="en", description="Cheerful and lively."),
|
description="Steady male voice.", is_system=True),
|
||||||
Voice(id="v4", name="Guang", vendor="Ali", gender="Male", language="zh", description="Standard newscast style."),
|
Voice(id="benjamin", name="Benjamin", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
Voice(id="v5", name="Doubao", vendor="Volcano", gender="Female", language="zh", description="Cute and young."),
|
description="Deep male voice.", is_system=True),
|
||||||
|
Voice(id="charles", name="Charles", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
|
description="Magnetic male voice.", is_system=True),
|
||||||
|
Voice(id="david", name="David", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
|
description="Cheerful male voice.", is_system=True),
|
||||||
|
# 女声 (Female Voices)
|
||||||
|
Voice(id="anna", name="Anna", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Steady female voice.", is_system=True),
|
||||||
|
Voice(id="bella", name="Bella", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Passionate female voice.", is_system=True),
|
||||||
|
Voice(id="claire", name="Claire", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Gentle female voice.", is_system=True),
|
||||||
|
Voice(id="diana", name="Diana", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Cheerful female voice.", is_system=True),
|
||||||
|
# 中文方言 (Chinese Dialects) - 可选扩展
|
||||||
|
Voice(id="amador", name="Amador", vendor="SiliconFlow", gender="Male", language="zh",
|
||||||
|
description="Male voice with Spanish accent."),
|
||||||
|
Voice(id="aelora", name="Aelora", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Elegant female voice."),
|
||||||
|
Voice(id="aelwin", name="Aelwin", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
|
description="Deep male voice."),
|
||||||
|
Voice(id="blooming", name="Blooming", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Fresh and clear female voice."),
|
||||||
|
Voice(id="elysia", name="Elysia", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Smooth and silky female voice."),
|
||||||
|
Voice(id="leo", name="Leo", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
|
description="Young male voice."),
|
||||||
|
Voice(id="lin", name="Lin", vendor="SiliconFlow", gender="Female", language="zh",
|
||||||
|
description="Standard Chinese female voice."),
|
||||||
|
Voice(id="rose", name="Rose", vendor="SiliconFlow", gender="Female", language="en",
|
||||||
|
description="Soft and gentle female voice."),
|
||||||
|
Voice(id="shao", name="Shao", vendor="SiliconFlow", gender="Male", language="zh",
|
||||||
|
description="Deep Chinese male voice."),
|
||||||
|
Voice(id="sky", name="Sky", vendor="SiliconFlow", gender="Male", language="en",
|
||||||
|
description="Clear and bright male voice."),
|
||||||
|
Voice(id="ael西山", name="Ael西山", vendor="SiliconFlow", gender="Female", language="zh",
|
||||||
|
description="Female voice with Chinese dialect."),
|
||||||
]
|
]
|
||||||
for v in voices:
|
for v in voices:
|
||||||
db.add(v)
|
db.add(v)
|
||||||
db.commit()
|
db.commit()
|
||||||
print("✅ 默认声音数据已初始化")
|
print("✅ 默认声音数据已初始化 (SiliconFlow CosyVoice 2.0)")
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|||||||
8
api/pytest.ini
Normal file
8
api/pytest.ini
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
[pytest]
|
||||||
|
testpaths = tests
|
||||||
|
python_files = test_*.py
|
||||||
|
python_classes = Test*
|
||||||
|
python_functions = test_*
|
||||||
|
addopts = -v --tb=short
|
||||||
|
filterwarnings =
|
||||||
|
ignore::DeprecationWarning
|
||||||
14
api/run_tests.bat
Normal file
14
api/run_tests.bat
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
@echo off
|
||||||
|
REM Run API tests
|
||||||
|
|
||||||
|
cd /d "%~dp0"
|
||||||
|
|
||||||
|
REM Install test dependencies
|
||||||
|
echo Installing test dependencies...
|
||||||
|
pip install pytest pytest-cov -q
|
||||||
|
|
||||||
|
REM Run tests
|
||||||
|
echo Running tests...
|
||||||
|
pytest tests/ -v --tb=short
|
||||||
|
|
||||||
|
pause
|
||||||
1
api/tests/__init__.py
Normal file
1
api/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Tests package
|
||||||
137
api/tests/conftest.py
Normal file
137
api/tests/conftest.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""Pytest fixtures for API tests"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Add api directory to path
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.db import Base, get_db
|
||||||
|
from app.main import app
|
||||||
|
|
||||||
|
|
||||||
|
# Use in-memory SQLite for testing
|
||||||
|
DATABASE_URL = "sqlite:///:memory:"
|
||||||
|
|
||||||
|
engine = create_engine(
|
||||||
|
DATABASE_URL,
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def db_session():
|
||||||
|
"""Create a fresh database session for each test"""
|
||||||
|
# Create all tables
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
|
session = TestingSessionLocal()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
# Drop all tables after test
|
||||||
|
Base.metadata.drop_all(bind=engine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def client(db_session):
|
||||||
|
"""Create a test client with database dependency override"""
|
||||||
|
|
||||||
|
def override_get_db():
|
||||||
|
try:
|
||||||
|
yield db_session
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
|
||||||
|
with TestClient(app) as test_client:
|
||||||
|
yield test_client
|
||||||
|
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_voice_data():
|
||||||
|
"""Sample voice data for testing"""
|
||||||
|
return {
|
||||||
|
"name": "Test Voice",
|
||||||
|
"vendor": "TestVendor",
|
||||||
|
"gender": "Female",
|
||||||
|
"language": "zh",
|
||||||
|
"description": "A test voice for unit testing",
|
||||||
|
"model": "test-model",
|
||||||
|
"voice_key": "test-key",
|
||||||
|
"speed": 1.0,
|
||||||
|
"gain": 0,
|
||||||
|
"pitch": 0,
|
||||||
|
"enabled": True
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_assistant_data():
|
||||||
|
"""Sample assistant data for testing"""
|
||||||
|
return {
|
||||||
|
"name": "Test Assistant",
|
||||||
|
"opener": "Hello, welcome!",
|
||||||
|
"prompt": "You are a helpful assistant.",
|
||||||
|
"language": "zh",
|
||||||
|
"speed": 1.0,
|
||||||
|
"hotwords": ["test", "hello"],
|
||||||
|
"tools": [],
|
||||||
|
"configMode": "platform"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_call_record_data():
|
||||||
|
"""Sample call record data for testing"""
|
||||||
|
return {
|
||||||
|
"user_id": 1,
|
||||||
|
"assistant_id": None,
|
||||||
|
"source": "debug"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_llm_model_data():
|
||||||
|
"""Sample LLM model data for testing"""
|
||||||
|
return {
|
||||||
|
"id": "test-llm-001",
|
||||||
|
"name": "Test LLM Model",
|
||||||
|
"vendor": "TestVendor",
|
||||||
|
"type": "text",
|
||||||
|
"base_url": "https://api.test.com/v1",
|
||||||
|
"api_key": "test-api-key",
|
||||||
|
"model_name": "test-model",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"context_length": 4096,
|
||||||
|
"enabled": True
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_asr_model_data():
|
||||||
|
"""Sample ASR model data for testing"""
|
||||||
|
return {
|
||||||
|
"id": "test-asr-001",
|
||||||
|
"name": "Test ASR Model",
|
||||||
|
"vendor": "TestVendor",
|
||||||
|
"language": "zh",
|
||||||
|
"base_url": "https://api.test.com/v1",
|
||||||
|
"api_key": "test-api-key",
|
||||||
|
"model_name": "paraformer-v2",
|
||||||
|
"hotwords": ["测试", "语音"],
|
||||||
|
"enable_punctuation": True,
|
||||||
|
"enable_normalization": True,
|
||||||
|
"enabled": True
|
||||||
|
}
|
||||||
289
api/tests/test_asr.py
Normal file
289
api/tests/test_asr.py
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
"""Tests for ASR Model API endpoints"""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
class TestASRModelAPI:
|
||||||
|
"""Test cases for ASR Model endpoints"""
|
||||||
|
|
||||||
|
def test_get_asr_models_empty(self, client):
|
||||||
|
"""Test getting ASR models when database is empty"""
|
||||||
|
response = client.get("/api/asr")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "total" in data
|
||||||
|
assert "list" in data
|
||||||
|
assert data["total"] == 0
|
||||||
|
|
||||||
|
def test_create_asr_model(self, client, sample_asr_model_data):
|
||||||
|
"""Test creating a new ASR model"""
|
||||||
|
response = client.post("/api/asr", json=sample_asr_model_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == sample_asr_model_data["name"]
|
||||||
|
assert data["vendor"] == sample_asr_model_data["vendor"]
|
||||||
|
assert data["language"] == sample_asr_model_data["language"]
|
||||||
|
assert "id" in data
|
||||||
|
|
||||||
|
def test_create_asr_model_minimal(self, client):
|
||||||
|
"""Test creating an ASR model with minimal required data"""
|
||||||
|
data = {
|
||||||
|
"name": "Minimal ASR",
|
||||||
|
"vendor": "Test",
|
||||||
|
"language": "zh",
|
||||||
|
"base_url": "https://api.test.com",
|
||||||
|
"api_key": "test-key"
|
||||||
|
}
|
||||||
|
response = client.post("/api/asr", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "Minimal ASR"
|
||||||
|
|
||||||
|
def test_get_asr_model_by_id(self, client, sample_asr_model_data):
|
||||||
|
"""Test getting a specific ASR model by ID"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Get by ID
|
||||||
|
response = client.get(f"/api/asr/{model_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == model_id
|
||||||
|
assert data["name"] == sample_asr_model_data["name"]
|
||||||
|
|
||||||
|
def test_get_asr_model_not_found(self, client):
|
||||||
|
"""Test getting a non-existent ASR model"""
|
||||||
|
response = client.get("/api/asr/non-existent-id")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
def test_update_asr_model(self, client, sample_asr_model_data):
|
||||||
|
"""Test updating an ASR model"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Update
|
||||||
|
update_data = {
|
||||||
|
"name": "Updated ASR Model",
|
||||||
|
"language": "en",
|
||||||
|
"enable_punctuation": False
|
||||||
|
}
|
||||||
|
response = client.put(f"/api/asr/{model_id}", json=update_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "Updated ASR Model"
|
||||||
|
assert data["language"] == "en"
|
||||||
|
assert data["enable_punctuation"] == False
|
||||||
|
|
||||||
|
def test_delete_asr_model(self, client, sample_asr_model_data):
|
||||||
|
"""Test deleting an ASR model"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
response = client.delete(f"/api/asr/{model_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Verify deleted
|
||||||
|
get_response = client.get(f"/api/asr/{model_id}")
|
||||||
|
assert get_response.status_code == 404
|
||||||
|
|
||||||
|
def test_list_asr_models_with_pagination(self, client, sample_asr_model_data):
|
||||||
|
"""Test listing ASR models with pagination"""
|
||||||
|
# Create multiple models
|
||||||
|
for i in range(3):
|
||||||
|
data = sample_asr_model_data.copy()
|
||||||
|
data["id"] = f"test-asr-{i}"
|
||||||
|
data["name"] = f"ASR Model {i}"
|
||||||
|
client.post("/api/asr", json=data)
|
||||||
|
|
||||||
|
# Test pagination
|
||||||
|
response = client.get("/api/asr?page=1&limit=2")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 3
|
||||||
|
assert len(data["list"]) == 2
|
||||||
|
|
||||||
|
def test_filter_asr_models_by_language(self, client, sample_asr_model_data):
|
||||||
|
"""Test filtering ASR models by language"""
|
||||||
|
# Create models with different languages
|
||||||
|
for i, lang in enumerate(["zh", "en", "Multi-lingual"]):
|
||||||
|
data = sample_asr_model_data.copy()
|
||||||
|
data["id"] = f"test-asr-{lang}"
|
||||||
|
data["name"] = f"ASR {lang}"
|
||||||
|
data["language"] = lang
|
||||||
|
client.post("/api/asr", json=data)
|
||||||
|
|
||||||
|
# Filter by language
|
||||||
|
response = client.get("/api/asr?language=zh")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] >= 1
|
||||||
|
for model in data["list"]:
|
||||||
|
assert model["language"] == "zh"
|
||||||
|
|
||||||
|
def test_filter_asr_models_by_enabled(self, client, sample_asr_model_data):
|
||||||
|
"""Test filtering ASR models by enabled status"""
|
||||||
|
# Create enabled and disabled models
|
||||||
|
data = sample_asr_model_data.copy()
|
||||||
|
data["id"] = "test-asr-enabled"
|
||||||
|
data["name"] = "Enabled ASR"
|
||||||
|
data["enabled"] = True
|
||||||
|
client.post("/api/asr", json=data)
|
||||||
|
|
||||||
|
data["id"] = "test-asr-disabled"
|
||||||
|
data["name"] = "Disabled ASR"
|
||||||
|
data["enabled"] = False
|
||||||
|
client.post("/api/asr", json=data)
|
||||||
|
|
||||||
|
# Filter by enabled
|
||||||
|
response = client.get("/api/asr?enabled=true")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
for model in data["list"]:
|
||||||
|
assert model["enabled"] == True
|
||||||
|
|
||||||
|
def test_create_asr_model_with_hotwords(self, client):
|
||||||
|
"""Test creating an ASR model with hotwords"""
|
||||||
|
data = {
|
||||||
|
"id": "asr-hotwords",
|
||||||
|
"name": "ASR with Hotwords",
|
||||||
|
"vendor": "SiliconFlow",
|
||||||
|
"language": "zh",
|
||||||
|
"base_url": "https://api.siliconflow.cn/v1",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"model_name": "paraformer-v2",
|
||||||
|
"hotwords": ["你好", "查询", "帮助"],
|
||||||
|
"enable_punctuation": True,
|
||||||
|
"enable_normalization": True
|
||||||
|
}
|
||||||
|
response = client.post("/api/asr", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
result = response.json()
|
||||||
|
assert result["hotwords"] == ["你好", "查询", "帮助"]
|
||||||
|
|
||||||
|
def test_create_asr_model_with_all_fields(self, client):
|
||||||
|
"""Test creating an ASR model with all fields"""
|
||||||
|
data = {
|
||||||
|
"id": "full-asr",
|
||||||
|
"name": "Full ASR Model",
|
||||||
|
"vendor": "SiliconFlow",
|
||||||
|
"language": "zh",
|
||||||
|
"base_url": "https://api.siliconflow.cn/v1",
|
||||||
|
"api_key": "sk-test",
|
||||||
|
"model_name": "paraformer-v2",
|
||||||
|
"hotwords": ["测试"],
|
||||||
|
"enable_punctuation": True,
|
||||||
|
"enable_normalization": True,
|
||||||
|
"enabled": True
|
||||||
|
}
|
||||||
|
response = client.post("/api/asr", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
result = response.json()
|
||||||
|
assert result["name"] == "Full ASR Model"
|
||||||
|
assert result["enable_punctuation"] == True
|
||||||
|
assert result["enable_normalization"] == True
|
||||||
|
|
||||||
|
@patch('httpx.Client')
|
||||||
|
def test_test_asr_model_siliconflow(self, mock_client_class, client, sample_asr_model_data):
|
||||||
|
"""Test testing an ASR model with SiliconFlow vendor"""
|
||||||
|
# Create model first
|
||||||
|
sample_asr_model_data["vendor"] = "SiliconFlow"
|
||||||
|
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Mock the HTTP response
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"results": [{"transcript": "测试文本", "language": "zh"}]
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
mock_client.get.return_value = mock_response
|
||||||
|
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||||
|
mock_client.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
|
with patch('app.routers.asr.httpx.Client', return_value=mock_client):
|
||||||
|
response = client.post(f"/api/asr/{model_id}/test")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] == True
|
||||||
|
|
||||||
|
@patch('httpx.Client')
|
||||||
|
def test_test_asr_model_openai(self, mock_client_class, client, sample_asr_model_data):
|
||||||
|
"""Test testing an ASR model with OpenAI vendor"""
|
||||||
|
# Create model with OpenAI vendor
|
||||||
|
sample_asr_model_data["vendor"] = "OpenAI"
|
||||||
|
sample_asr_model_data["id"] = "test-asr-openai"
|
||||||
|
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Mock the HTTP response
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"text": "Test transcript"}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
mock_client.get.return_value = mock_response
|
||||||
|
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||||
|
mock_client.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
|
with patch('app.routers.asr.httpx.Client', return_value=mock_client):
|
||||||
|
response = client.post(f"/api/asr/{model_id}/test")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
@patch('httpx.Client')
|
||||||
|
def test_test_asr_model_failure(self, mock_client_class, client, sample_asr_model_data):
|
||||||
|
"""Test testing an ASR model with failed connection"""
|
||||||
|
# Create model first
|
||||||
|
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Mock HTTP error
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 401
|
||||||
|
mock_response.text = "Unauthorized"
|
||||||
|
mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized"))
|
||||||
|
mock_client.get.return_value = mock_response
|
||||||
|
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||||
|
mock_client.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
|
with patch('app.routers.asr.httpx.Client', return_value=mock_client):
|
||||||
|
response = client.post(f"/api/asr/{model_id}/test")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] == False
|
||||||
|
|
||||||
|
def test_different_asr_languages(self, client):
|
||||||
|
"""Test creating ASR models with different languages"""
|
||||||
|
for lang in ["zh", "en", "Multi-lingual"]:
|
||||||
|
data = {
|
||||||
|
"id": f"asr-lang-{lang}",
|
||||||
|
"name": f"ASR {lang}",
|
||||||
|
"vendor": "SiliconFlow",
|
||||||
|
"language": lang,
|
||||||
|
"base_url": "https://api.siliconflow.cn/v1",
|
||||||
|
"api_key": "test-key"
|
||||||
|
}
|
||||||
|
response = client.post("/api/asr", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["language"] == lang
|
||||||
|
|
||||||
|
def test_different_asr_vendors(self, client):
|
||||||
|
"""Test creating ASR models with different vendors"""
|
||||||
|
vendors = ["SiliconFlow", "OpenAI", "Azure"]
|
||||||
|
for vendor in vendors:
|
||||||
|
data = {
|
||||||
|
"id": f"asr-vendor-{vendor.lower()}",
|
||||||
|
"name": f"ASR {vendor}",
|
||||||
|
"vendor": vendor,
|
||||||
|
"language": "zh",
|
||||||
|
"base_url": f"https://api.{vendor.lower()}.com/v1",
|
||||||
|
"api_key": "test-key"
|
||||||
|
}
|
||||||
|
response = client.post("/api/asr", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["vendor"] == vendor
|
||||||
168
api/tests/test_assistants.py
Normal file
168
api/tests/test_assistants.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Tests for Assistant API endpoints"""
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssistantAPI:
|
||||||
|
"""Test cases for Assistant endpoints"""
|
||||||
|
|
||||||
|
def test_get_assistants_empty(self, client):
|
||||||
|
"""Test getting assistants when database is empty"""
|
||||||
|
response = client.get("/api/assistants")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "total" in data
|
||||||
|
assert "list" in data
|
||||||
|
|
||||||
|
def test_create_assistant(self, client, sample_assistant_data):
|
||||||
|
"""Test creating a new assistant"""
|
||||||
|
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == sample_assistant_data["name"]
|
||||||
|
assert data["opener"] == sample_assistant_data["opener"]
|
||||||
|
assert data["prompt"] == sample_assistant_data["prompt"]
|
||||||
|
assert data["language"] == sample_assistant_data["language"]
|
||||||
|
assert "id" in data
|
||||||
|
assert data["callCount"] == 0
|
||||||
|
|
||||||
|
def test_create_assistant_minimal(self, client):
|
||||||
|
"""Test creating an assistant with minimal required data"""
|
||||||
|
data = {"name": "Minimal Assistant"}
|
||||||
|
response = client.post("/api/assistants", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "Minimal Assistant"
|
||||||
|
|
||||||
|
def test_get_assistant_by_id(self, client, sample_assistant_data):
|
||||||
|
"""Test getting a specific assistant by ID"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
|
assistant_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Get by ID
|
||||||
|
response = client.get(f"/api/assistants/{assistant_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == assistant_id
|
||||||
|
assert data["name"] == sample_assistant_data["name"]
|
||||||
|
|
||||||
|
def test_get_assistant_not_found(self, client):
|
||||||
|
"""Test getting a non-existent assistant"""
|
||||||
|
response = client.get("/api/assistants/non-existent-id")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
def test_update_assistant(self, client, sample_assistant_data):
|
||||||
|
"""Test updating an assistant"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
|
assistant_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Update
|
||||||
|
update_data = {
|
||||||
|
"name": "Updated Assistant",
|
||||||
|
"prompt": "You are an updated assistant.",
|
||||||
|
"speed": 1.5
|
||||||
|
}
|
||||||
|
response = client.put(f"/api/assistants/{assistant_id}", json=update_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "Updated Assistant"
|
||||||
|
assert data["prompt"] == "You are an updated assistant."
|
||||||
|
assert data["speed"] == 1.5
|
||||||
|
|
||||||
|
def test_delete_assistant(self, client, sample_assistant_data):
|
||||||
|
"""Test deleting an assistant"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
|
assistant_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
response = client.delete(f"/api/assistants/{assistant_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Verify deleted
|
||||||
|
get_response = client.get(f"/api/assistants/{assistant_id}")
|
||||||
|
assert get_response.status_code == 404
|
||||||
|
|
||||||
|
def test_list_assistants_with_pagination(self, client, sample_assistant_data):
|
||||||
|
"""Test listing assistants with pagination"""
|
||||||
|
# Create multiple assistants
|
||||||
|
for i in range(3):
|
||||||
|
data = sample_assistant_data.copy()
|
||||||
|
data["name"] = f"Assistant {i}"
|
||||||
|
client.post("/api/assistants", json=data)
|
||||||
|
|
||||||
|
# Test pagination
|
||||||
|
response = client.get("/api/assistants?page=1&limit=2")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 3
|
||||||
|
assert len(data["list"]) == 2
|
||||||
|
|
||||||
|
def test_create_assistant_with_voice(self, client, sample_assistant_data, sample_voice_data):
|
||||||
|
"""Test creating an assistant with a voice reference"""
|
||||||
|
# Create a voice first
|
||||||
|
voice_response = client.post("/api/voices", json=sample_voice_data)
|
||||||
|
voice_id = voice_response.json()["id"]
|
||||||
|
|
||||||
|
# Create assistant with voice
|
||||||
|
sample_assistant_data["voice"] = voice_id
|
||||||
|
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["voice"] == voice_id
|
||||||
|
|
||||||
|
def test_create_assistant_with_knowledge_base(self, client, sample_assistant_data):
|
||||||
|
"""Test creating an assistant with knowledge base reference"""
|
||||||
|
# Note: This test assumes knowledge base doesn't exist
|
||||||
|
sample_assistant_data["knowledgeBaseId"] = "non-existent-kb"
|
||||||
|
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["knowledgeBaseId"] == "non-existent-kb"
|
||||||
|
|
||||||
|
def test_assistant_with_model_references(self, client, sample_assistant_data):
|
||||||
|
"""Test creating assistant with model references"""
|
||||||
|
sample_assistant_data.update({
|
||||||
|
"llmModelId": "llm-001",
|
||||||
|
"asrModelId": "asr-001",
|
||||||
|
"embeddingModelId": "emb-001",
|
||||||
|
"rerankModelId": "rerank-001"
|
||||||
|
})
|
||||||
|
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["llmModelId"] == "llm-001"
|
||||||
|
assert data["asrModelId"] == "asr-001"
|
||||||
|
assert data["embeddingModelId"] == "emb-001"
|
||||||
|
assert data["rerankModelId"] == "rerank-001"
|
||||||
|
|
||||||
|
def test_assistant_with_tools(self, client, sample_assistant_data):
|
||||||
|
"""Test creating assistant with tools"""
|
||||||
|
sample_assistant_data["tools"] = ["weather", "calculator", "search"]
|
||||||
|
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["tools"] == ["weather", "calculator", "search"]
|
||||||
|
|
||||||
|
def test_assistant_with_hotwords(self, client, sample_assistant_data):
|
||||||
|
"""Test creating assistant with hotwords"""
|
||||||
|
sample_assistant_data["hotwords"] = ["hello", "help", "stop"]
|
||||||
|
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["hotwords"] == ["hello", "help", "stop"]
|
||||||
|
|
||||||
|
def test_different_config_modes(self, client, sample_assistant_data):
|
||||||
|
"""Test creating assistants with different config modes"""
|
||||||
|
for mode in ["platform", "dify", "fastgpt", "none"]:
|
||||||
|
sample_assistant_data["name"] = f"Assistant {mode}"
|
||||||
|
sample_assistant_data["configMode"] = mode
|
||||||
|
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["configMode"] == mode
|
||||||
|
|
||||||
|
def test_different_languages(self, client, sample_assistant_data):
|
||||||
|
"""Test creating assistants with different languages"""
|
||||||
|
for lang in ["zh", "en", "ja", "ko"]:
|
||||||
|
sample_assistant_data["name"] = f"Assistant {lang}"
|
||||||
|
sample_assistant_data["language"] = lang
|
||||||
|
response = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["language"] == lang
|
||||||
236
api/tests/test_history.py
Normal file
236
api/tests/test_history.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
"""Tests for History/Call Record API endpoints"""
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class TestHistoryAPI:
|
||||||
|
"""Test cases for History/Call Record endpoints"""
|
||||||
|
|
||||||
|
def test_get_history_empty(self, client):
|
||||||
|
"""Test getting history when database is empty"""
|
||||||
|
response = client.get("/api/history")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "total" in data
|
||||||
|
assert "list" in data
|
||||||
|
|
||||||
|
def test_create_call_record(self, client, sample_call_record_data):
|
||||||
|
"""Test creating a new call record"""
|
||||||
|
response = client.post("/api/history", json=sample_call_record_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["user_id"] == sample_call_record_data["user_id"]
|
||||||
|
assert data["source"] == sample_call_record_data["source"]
|
||||||
|
assert data["status"] == "connected"
|
||||||
|
assert "id" in data
|
||||||
|
assert "started_at" in data
|
||||||
|
|
||||||
|
def test_create_call_record_with_assistant(self, client, sample_assistant_data, sample_call_record_data):
|
||||||
|
"""Test creating a call record associated with an assistant"""
|
||||||
|
# Create assistant first
|
||||||
|
assistant_response = client.post("/api/assistants", json=sample_assistant_data)
|
||||||
|
assistant_id = assistant_response.json()["id"]
|
||||||
|
|
||||||
|
# Create call record with assistant
|
||||||
|
sample_call_record_data["assistant_id"] = assistant_id
|
||||||
|
response = client.post("/api/history", json=sample_call_record_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["assistant_id"] == assistant_id
|
||||||
|
|
||||||
|
def test_get_call_record_by_id(self, client, sample_call_record_data):
|
||||||
|
"""Test getting a specific call record by ID"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||||
|
record_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Get by ID
|
||||||
|
response = client.get(f"/api/history/{record_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == record_id
|
||||||
|
|
||||||
|
def test_get_call_record_not_found(self, client):
|
||||||
|
"""Test getting a non-existent call record"""
|
||||||
|
response = client.get("/api/history/non-existent-id")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
def test_update_call_record(self, client, sample_call_record_data):
|
||||||
|
"""Test updating a call record"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||||
|
record_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Update
|
||||||
|
update_data = {
|
||||||
|
"status": "completed",
|
||||||
|
"summary": "Test summary",
|
||||||
|
"duration_seconds": 120
|
||||||
|
}
|
||||||
|
response = client.put(f"/api/history/{record_id}", json=update_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "completed"
|
||||||
|
assert data["summary"] == "Test summary"
|
||||||
|
assert data["duration_seconds"] == 120
|
||||||
|
|
||||||
|
def test_delete_call_record(self, client, sample_call_record_data):
|
||||||
|
"""Test deleting a call record"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||||
|
record_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
response = client.delete(f"/api/history/{record_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Verify deleted
|
||||||
|
get_response = client.get(f"/api/history/{record_id}")
|
||||||
|
assert get_response.status_code == 404
|
||||||
|
|
||||||
|
def test_add_transcript(self, client, sample_call_record_data):
|
||||||
|
"""Test adding a transcript to a call record"""
|
||||||
|
# Create call record first
|
||||||
|
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||||
|
record_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Add transcript
|
||||||
|
transcript_data = {
|
||||||
|
"turn_index": 0,
|
||||||
|
"speaker": "human",
|
||||||
|
"content": "Hello, I need help",
|
||||||
|
"start_ms": 0,
|
||||||
|
"end_ms": 3000,
|
||||||
|
"confidence": 0.95
|
||||||
|
}
|
||||||
|
response = client.post(
|
||||||
|
f"/api/history/{record_id}/transcripts",
|
||||||
|
json=transcript_data
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["turn_index"] == 0
|
||||||
|
assert data["speaker"] == "human"
|
||||||
|
assert data["content"] == "Hello, I need help"
|
||||||
|
|
||||||
|
def test_add_multiple_transcripts(self, client, sample_call_record_data):
|
||||||
|
"""Test adding multiple transcripts"""
|
||||||
|
# Create call record first
|
||||||
|
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||||
|
record_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Add human transcript
|
||||||
|
human_transcript = {
|
||||||
|
"turn_index": 0,
|
||||||
|
"speaker": "human",
|
||||||
|
"content": "Hello",
|
||||||
|
"start_ms": 0,
|
||||||
|
"end_ms": 1000
|
||||||
|
}
|
||||||
|
client.post(f"/api/history/{record_id}/transcripts", json=human_transcript)
|
||||||
|
|
||||||
|
# Add AI transcript
|
||||||
|
ai_transcript = {
|
||||||
|
"turn_index": 1,
|
||||||
|
"speaker": "ai",
|
||||||
|
"content": "Hello! How can I help you?",
|
||||||
|
"start_ms": 1500,
|
||||||
|
"end_ms": 4000
|
||||||
|
}
|
||||||
|
client.post(f"/api/history/{record_id}/transcripts", json=ai_transcript)
|
||||||
|
|
||||||
|
# Verify both transcripts exist
|
||||||
|
response = client.get(f"/api/history/{record_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["transcripts"]) == 2
|
||||||
|
|
||||||
|
def test_filter_history_by_status(self, client, sample_call_record_data):
|
||||||
|
"""Test filtering history by status"""
|
||||||
|
# Create records with different statuses
|
||||||
|
for i in range(2):
|
||||||
|
data = sample_call_record_data.copy()
|
||||||
|
data["status"] = "connected" if i % 2 == 0 else "missed"
|
||||||
|
client.post("/api/history", json=data)
|
||||||
|
|
||||||
|
# Filter by status
|
||||||
|
response = client.get("/api/history?status=connected")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
for record in data["list"]:
|
||||||
|
assert record["status"] == "connected"
|
||||||
|
|
||||||
|
def test_filter_history_by_source(self, client, sample_call_record_data):
|
||||||
|
"""Test filtering history by source"""
|
||||||
|
sample_call_record_data["source"] = "external"
|
||||||
|
client.post("/api/history", json=sample_call_record_data)
|
||||||
|
|
||||||
|
response = client.get("/api/history?source=external")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
for record in data["list"]:
|
||||||
|
assert record["source"] == "external"
|
||||||
|
|
||||||
|
def test_history_pagination(self, client, sample_call_record_data):
|
||||||
|
"""Test history pagination"""
|
||||||
|
# Create multiple records
|
||||||
|
for i in range(5):
|
||||||
|
data = sample_call_record_data.copy()
|
||||||
|
data["source"] = f"source-{i}"
|
||||||
|
client.post("/api/history", json=data)
|
||||||
|
|
||||||
|
# Test pagination
|
||||||
|
response = client.get("/api/history?page=1&limit=3")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 5
|
||||||
|
assert len(data["list"]) == 3
|
||||||
|
|
||||||
|
def test_transcript_with_emotion(self, client, sample_call_record_data):
|
||||||
|
"""Test adding transcript with emotion"""
|
||||||
|
# Create call record first
|
||||||
|
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||||
|
record_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Add transcript with emotion
|
||||||
|
transcript_data = {
|
||||||
|
"turn_index": 0,
|
||||||
|
"speaker": "ai",
|
||||||
|
"content": "Great news!",
|
||||||
|
"start_ms": 0,
|
||||||
|
"end_ms": 2000,
|
||||||
|
"emotion": "happy"
|
||||||
|
}
|
||||||
|
response = client.post(
|
||||||
|
f"/api/history/{record_id}/transcripts",
|
||||||
|
json=transcript_data
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["emotion"] == "happy"
|
||||||
|
|
||||||
|
def test_history_with_cost(self, client, sample_call_record_data):
|
||||||
|
"""Test creating history with cost"""
|
||||||
|
sample_call_record_data["cost"] = 0.05
|
||||||
|
response = client.post("/api/history", json=sample_call_record_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["cost"] == 0.05
|
||||||
|
|
||||||
|
def test_history_search(self, client, sample_call_record_data):
|
||||||
|
"""Test searching history"""
|
||||||
|
# Create record
|
||||||
|
create_response = client.post("/api/history", json=sample_call_record_data)
|
||||||
|
record_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Add transcript with searchable content
|
||||||
|
transcript_data = {
|
||||||
|
"turn_index": 0,
|
||||||
|
"speaker": "human",
|
||||||
|
"content": "I want to buy a product",
|
||||||
|
"start_ms": 0,
|
||||||
|
"end_ms": 3000
|
||||||
|
}
|
||||||
|
client.post(f"/api/history/{record_id}/transcripts", json=transcript_data)
|
||||||
|
|
||||||
|
# Search (this endpoint may not exist yet)
|
||||||
|
response = client.get("/api/history/search?q=product")
|
||||||
|
# This might return 404 if endpoint doesn't exist
|
||||||
|
assert response.status_code in [200, 404]
|
||||||
255
api/tests/test_knowledge.py
Normal file
255
api/tests/test_knowledge.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
"""Tests for Knowledge Base API endpoints"""
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
class TestKnowledgeAPI:
|
||||||
|
"""Test cases for Knowledge Base endpoints"""
|
||||||
|
|
||||||
|
def test_get_knowledge_bases_empty(self, client):
|
||||||
|
"""Test getting knowledge bases when database is empty"""
|
||||||
|
response = client.get("/api/knowledge/bases")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "total" in data
|
||||||
|
assert "list" in data
|
||||||
|
|
||||||
|
def test_create_knowledge_base(self, client):
|
||||||
|
"""Test creating a new knowledge base"""
|
||||||
|
data = {
|
||||||
|
"name": "Test Knowledge Base",
|
||||||
|
"description": "A test knowledge base",
|
||||||
|
"embeddingModel": "text-embedding-3-small",
|
||||||
|
"chunkSize": 500,
|
||||||
|
"chunkOverlap": 50
|
||||||
|
}
|
||||||
|
response = client.post("/api/knowledge/bases", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "Test Knowledge Base"
|
||||||
|
assert data["description"] == "A test knowledge base"
|
||||||
|
assert data["embeddingModel"] == "text-embedding-3-small"
|
||||||
|
assert "id" in data
|
||||||
|
assert data["docCount"] == 0
|
||||||
|
assert data["chunkCount"] == 0
|
||||||
|
assert data["status"] == "active"
|
||||||
|
|
||||||
|
def test_create_knowledge_base_minimal(self, client):
|
||||||
|
"""Test creating a knowledge base with minimal data"""
|
||||||
|
data = {"name": "Minimal KB"}
|
||||||
|
response = client.post("/api/knowledge/bases", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "Minimal KB"
|
||||||
|
|
||||||
|
def test_get_knowledge_base_by_id(self, client):
|
||||||
|
"""Test getting a specific knowledge base by ID"""
|
||||||
|
# Create first
|
||||||
|
create_data = {"name": "Test KB"}
|
||||||
|
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||||
|
kb_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Get by ID
|
||||||
|
response = client.get(f"/api/knowledge/bases/{kb_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == kb_id
|
||||||
|
assert data["name"] == "Test KB"
|
||||||
|
|
||||||
|
def test_get_knowledge_base_not_found(self, client):
|
||||||
|
"""Test getting a non-existent knowledge base"""
|
||||||
|
response = client.get("/api/knowledge/bases/non-existent-id")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
def test_update_knowledge_base(self, client):
|
||||||
|
"""Test updating a knowledge base"""
|
||||||
|
# Create first
|
||||||
|
create_data = {"name": "Original Name"}
|
||||||
|
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||||
|
kb_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Update
|
||||||
|
update_data = {
|
||||||
|
"name": "Updated Name",
|
||||||
|
"description": "Updated description",
|
||||||
|
"chunkSize": 800
|
||||||
|
}
|
||||||
|
response = client.put(f"/api/knowledge/bases/{kb_id}", json=update_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "Updated Name"
|
||||||
|
assert data["description"] == "Updated description"
|
||||||
|
assert data["chunkSize"] == 800
|
||||||
|
|
||||||
|
def test_delete_knowledge_base(self, client):
|
||||||
|
"""Test deleting a knowledge base"""
|
||||||
|
# Create first
|
||||||
|
create_data = {"name": "To Delete"}
|
||||||
|
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||||
|
kb_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
response = client.delete(f"/api/knowledge/bases/{kb_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Verify deleted
|
||||||
|
get_response = client.get(f"/api/knowledge/bases/{kb_id}")
|
||||||
|
assert get_response.status_code == 404
|
||||||
|
|
||||||
|
def test_upload_document(self, client):
|
||||||
|
"""Test uploading a document to knowledge base"""
|
||||||
|
# Create KB first
|
||||||
|
create_data = {"name": "Test KB for Docs"}
|
||||||
|
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||||
|
kb_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Upload document
|
||||||
|
doc_data = {
|
||||||
|
"name": "test-document.txt",
|
||||||
|
"size": "1024",
|
||||||
|
"fileType": "txt",
|
||||||
|
"storageUrl": "https://storage.example.com/test-document.txt"
|
||||||
|
}
|
||||||
|
response = client.post(
|
||||||
|
f"/api/knowledge/bases/{kb_id}/documents",
|
||||||
|
json=doc_data
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "test-document.txt"
|
||||||
|
assert "id" in data
|
||||||
|
assert data["status"] == "pending"
|
||||||
|
|
||||||
|
def test_delete_document(self, client):
|
||||||
|
"""Test deleting a document from knowledge base"""
|
||||||
|
# Create KB first
|
||||||
|
create_data = {"name": "Test KB for Delete"}
|
||||||
|
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||||
|
kb_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Upload document
|
||||||
|
doc_data = {"name": "to-delete.txt", "size": "100", "fileType": "txt"}
|
||||||
|
upload_response = client.post(
|
||||||
|
f"/api/knowledge/bases/{kb_id}/documents",
|
||||||
|
json=doc_data
|
||||||
|
)
|
||||||
|
doc_id = upload_response.json()["id"]
|
||||||
|
|
||||||
|
# Delete document
|
||||||
|
response = client.delete(
|
||||||
|
f"/api/knowledge/bases/{kb_id}/documents/{doc_id}"
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_index_document(self, client):
|
||||||
|
"""Test indexing a document"""
|
||||||
|
# Create KB first
|
||||||
|
create_data = {"name": "Test KB for Index"}
|
||||||
|
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||||
|
kb_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Index document
|
||||||
|
index_data = {
|
||||||
|
"document_id": "doc-001",
|
||||||
|
"content": "This is the content to index. It contains important information about the product."
|
||||||
|
}
|
||||||
|
response = client.post(
|
||||||
|
f"/api/knowledge/bases/{kb_id}/documents/doc-001/index",
|
||||||
|
json=index_data
|
||||||
|
)
|
||||||
|
# This might return 200 or error depending on vector store implementation
|
||||||
|
assert response.status_code in [200, 500]
|
||||||
|
|
||||||
|
def test_search_knowledge(self, client):
|
||||||
|
"""Test searching knowledge base"""
|
||||||
|
# Create KB first
|
||||||
|
create_data = {"name": "Test KB for Search"}
|
||||||
|
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||||
|
kb_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Search (this may fail without indexed content)
|
||||||
|
search_data = {
|
||||||
|
"query": "test query",
|
||||||
|
"kb_id": kb_id,
|
||||||
|
"nResults": 5
|
||||||
|
}
|
||||||
|
response = client.post("/api/knowledge/search", json=search_data)
|
||||||
|
# This might return 200 or error depending on implementation
|
||||||
|
assert response.status_code in [200, 500]
|
||||||
|
|
||||||
|
def test_get_knowledge_stats(self, client):
|
||||||
|
"""Test getting knowledge base statistics"""
|
||||||
|
# Create KB first
|
||||||
|
create_data = {"name": "Test KB for Stats"}
|
||||||
|
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||||
|
kb_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
response = client.get(f"/api/knowledge/bases/{kb_id}/stats")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["kb_id"] == kb_id
|
||||||
|
assert "docCount" in data
|
||||||
|
assert "chunkCount" in data
|
||||||
|
|
||||||
|
def test_knowledge_bases_pagination(self, client):
|
||||||
|
"""Test knowledge bases pagination"""
|
||||||
|
# Create multiple KBs
|
||||||
|
for i in range(5):
|
||||||
|
data = {"name": f"Knowledge Base {i}"}
|
||||||
|
client.post("/api/knowledge/bases", json=data)
|
||||||
|
|
||||||
|
# Test pagination
|
||||||
|
response = client.get("/api/knowledge/bases?page=1&limit=3")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 5
|
||||||
|
assert len(data["list"]) == 3
|
||||||
|
|
||||||
|
def test_different_embedding_models(self, client):
|
||||||
|
"""Test creating KB with different embedding models"""
|
||||||
|
models = [
|
||||||
|
"text-embedding-3-small",
|
||||||
|
"text-embedding-3-large",
|
||||||
|
"bge-small-zh"
|
||||||
|
]
|
||||||
|
for model in models:
|
||||||
|
data = {"name": f"KB with {model}", "embeddingModel": model}
|
||||||
|
response = client.post("/api/knowledge/bases", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["embeddingModel"] == model
|
||||||
|
|
||||||
|
def test_different_chunk_sizes(self, client):
|
||||||
|
"""Test creating KB with different chunk configurations"""
|
||||||
|
configs = [
|
||||||
|
{"chunkSize": 500, "chunkOverlap": 50},
|
||||||
|
{"chunkSize": 1000, "chunkOverlap": 100},
|
||||||
|
{"chunkSize": 256, "chunkOverlap": 25}
|
||||||
|
]
|
||||||
|
for config in configs:
|
||||||
|
data = {"name": "Chunk Test KB", **config}
|
||||||
|
response = client.post("/api/knowledge/bases", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_knowledge_base_with_documents(self, client):
|
||||||
|
"""Test creating KB and adding multiple documents"""
|
||||||
|
# Create KB
|
||||||
|
create_data = {"name": "KB with Multiple Docs"}
|
||||||
|
create_response = client.post("/api/knowledge/bases", json=create_data)
|
||||||
|
kb_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Add multiple documents
|
||||||
|
for i in range(3):
|
||||||
|
doc_data = {
|
||||||
|
"name": f"document-{i}.txt",
|
||||||
|
"size": f"{1000 + i * 100}",
|
||||||
|
"fileType": "txt"
|
||||||
|
}
|
||||||
|
client.post(
|
||||||
|
f"/api/knowledge/bases/{kb_id}/documents",
|
||||||
|
json=doc_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify documents are listed
|
||||||
|
response = client.get(f"/api/knowledge/bases/{kb_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["documents"]) == 3
|
||||||
246
api/tests/test_llm.py
Normal file
246
api/tests/test_llm.py
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
"""Tests for LLM Model API endpoints"""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMModelAPI:
|
||||||
|
"""Test cases for LLM Model endpoints"""
|
||||||
|
|
||||||
|
def test_get_llm_models_empty(self, client):
|
||||||
|
"""Test getting LLM models when database is empty"""
|
||||||
|
response = client.get("/api/llm")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "total" in data
|
||||||
|
assert "list" in data
|
||||||
|
assert data["total"] == 0
|
||||||
|
|
||||||
|
def test_create_llm_model(self, client, sample_llm_model_data):
|
||||||
|
"""Test creating a new LLM model"""
|
||||||
|
response = client.post("/api/llm", json=sample_llm_model_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == sample_llm_model_data["name"]
|
||||||
|
assert data["vendor"] == sample_llm_model_data["vendor"]
|
||||||
|
assert data["type"] == sample_llm_model_data["type"]
|
||||||
|
assert data["base_url"] == sample_llm_model_data["base_url"]
|
||||||
|
assert "id" in data
|
||||||
|
|
||||||
|
def test_create_llm_model_minimal(self, client):
|
||||||
|
"""Test creating an LLM model with minimal required data"""
|
||||||
|
data = {
|
||||||
|
"name": "Minimal LLM",
|
||||||
|
"vendor": "Test",
|
||||||
|
"type": "text",
|
||||||
|
"base_url": "https://api.test.com",
|
||||||
|
"api_key": "test-key"
|
||||||
|
}
|
||||||
|
response = client.post("/api/llm", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "Minimal LLM"
|
||||||
|
|
||||||
|
def test_get_llm_model_by_id(self, client, sample_llm_model_data):
|
||||||
|
"""Test getting a specific LLM model by ID"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Get by ID
|
||||||
|
response = client.get(f"/api/llm/{model_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == model_id
|
||||||
|
assert data["name"] == sample_llm_model_data["name"]
|
||||||
|
|
||||||
|
def test_get_llm_model_not_found(self, client):
|
||||||
|
"""Test getting a non-existent LLM model"""
|
||||||
|
response = client.get("/api/llm/non-existent-id")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
def test_update_llm_model(self, client, sample_llm_model_data):
|
||||||
|
"""Test updating an LLM model"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Update
|
||||||
|
update_data = {
|
||||||
|
"name": "Updated LLM Model",
|
||||||
|
"temperature": 0.5,
|
||||||
|
"context_length": 8192
|
||||||
|
}
|
||||||
|
response = client.put(f"/api/llm/{model_id}", json=update_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "Updated LLM Model"
|
||||||
|
assert data["temperature"] == 0.5
|
||||||
|
assert data["context_length"] == 8192
|
||||||
|
|
||||||
|
def test_delete_llm_model(self, client, sample_llm_model_data):
|
||||||
|
"""Test deleting an LLM model"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
response = client.delete(f"/api/llm/{model_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Verify deleted
|
||||||
|
get_response = client.get(f"/api/llm/{model_id}")
|
||||||
|
assert get_response.status_code == 404
|
||||||
|
|
||||||
|
def test_list_llm_models_with_pagination(self, client, sample_llm_model_data):
|
||||||
|
"""Test listing LLM models with pagination"""
|
||||||
|
# Create multiple models
|
||||||
|
for i in range(3):
|
||||||
|
data = sample_llm_model_data.copy()
|
||||||
|
data["id"] = f"test-llm-{i}"
|
||||||
|
data["name"] = f"LLM Model {i}"
|
||||||
|
client.post("/api/llm", json=data)
|
||||||
|
|
||||||
|
# Test pagination
|
||||||
|
response = client.get("/api/llm?page=1&limit=2")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 3
|
||||||
|
assert len(data["list"]) == 2
|
||||||
|
|
||||||
|
def test_filter_llm_models_by_type(self, client, sample_llm_model_data):
|
||||||
|
"""Test filtering LLM models by type"""
|
||||||
|
# Create models with different types
|
||||||
|
for i, model_type in enumerate(["text", "embedding", "rerank"]):
|
||||||
|
data = sample_llm_model_data.copy()
|
||||||
|
data["id"] = f"test-llm-{model_type}"
|
||||||
|
data["name"] = f"LLM {model_type}"
|
||||||
|
data["type"] = model_type
|
||||||
|
client.post("/api/llm", json=data)
|
||||||
|
|
||||||
|
# Filter by type
|
||||||
|
response = client.get("/api/llm?model_type=text")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] >= 1
|
||||||
|
for model in data["list"]:
|
||||||
|
assert model["type"] == "text"
|
||||||
|
|
||||||
|
def test_filter_llm_models_by_enabled(self, client, sample_llm_model_data):
|
||||||
|
"""Test filtering LLM models by enabled status"""
|
||||||
|
# Create enabled and disabled models
|
||||||
|
data = sample_llm_model_data.copy()
|
||||||
|
data["id"] = "test-llm-enabled"
|
||||||
|
data["name"] = "Enabled LLM"
|
||||||
|
data["enabled"] = True
|
||||||
|
client.post("/api/llm", json=data)
|
||||||
|
|
||||||
|
data["id"] = "test-llm-disabled"
|
||||||
|
data["name"] = "Disabled LLM"
|
||||||
|
data["enabled"] = False
|
||||||
|
client.post("/api/llm", json=data)
|
||||||
|
|
||||||
|
# Filter by enabled
|
||||||
|
response = client.get("/api/llm?enabled=true")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
for model in data["list"]:
|
||||||
|
assert model["enabled"] == True
|
||||||
|
|
||||||
|
def test_create_llm_model_with_all_fields(self, client):
|
||||||
|
"""Test creating an LLM model with all fields"""
|
||||||
|
data = {
|
||||||
|
"id": "full-llm",
|
||||||
|
"name": "Full LLM Model",
|
||||||
|
"vendor": "OpenAI",
|
||||||
|
"type": "text",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"api_key": "sk-test",
|
||||||
|
"model_name": "gpt-4",
|
||||||
|
"temperature": 0.8,
|
||||||
|
"context_length": 16384,
|
||||||
|
"enabled": True
|
||||||
|
}
|
||||||
|
response = client.post("/api/llm", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
result = response.json()
|
||||||
|
assert result["name"] == "Full LLM Model"
|
||||||
|
assert result["temperature"] == 0.8
|
||||||
|
assert result["context_length"] == 16384
|
||||||
|
|
||||||
|
@patch('httpx.Client')
|
||||||
|
def test_test_llm_model_success(self, mock_client_class, client, sample_llm_model_data):
|
||||||
|
"""Test testing an LLM model with successful connection"""
|
||||||
|
# Create model first
|
||||||
|
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Mock the HTTP response
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"choices": [{"message": {"content": "OK"}}]
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||||
|
mock_client.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
|
with patch('app.routers.llm.httpx.Client', return_value=mock_client):
|
||||||
|
response = client.post(f"/api/llm/{model_id}/test")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] == True
|
||||||
|
|
||||||
|
@patch('httpx.Client')
|
||||||
|
def test_test_llm_model_failure(self, mock_client_class, client, sample_llm_model_data):
|
||||||
|
"""Test testing an LLM model with failed connection"""
|
||||||
|
# Create model first
|
||||||
|
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Mock HTTP error
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 401
|
||||||
|
mock_response.text = "Unauthorized"
|
||||||
|
mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized"))
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||||
|
mock_client.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
|
with patch('app.routers.llm.httpx.Client', return_value=mock_client):
|
||||||
|
response = client.post(f"/api/llm/{model_id}/test")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] == False
|
||||||
|
|
||||||
|
def test_different_llm_vendors(self, client):
|
||||||
|
"""Test creating LLM models with different vendors"""
|
||||||
|
vendors = ["OpenAI", "SiliconFlow", "ZhipuAI", "Anthropic"]
|
||||||
|
for vendor in vendors:
|
||||||
|
data = {
|
||||||
|
"id": f"test-{vendor.lower()}",
|
||||||
|
"name": f"Test {vendor}",
|
||||||
|
"vendor": vendor,
|
||||||
|
"type": "text",
|
||||||
|
"base_url": f"https://api.{vendor.lower()}.com/v1",
|
||||||
|
"api_key": "test-key"
|
||||||
|
}
|
||||||
|
response = client.post("/api/llm", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["vendor"] == vendor
|
||||||
|
|
||||||
|
def test_embedding_llm_model(self, client):
|
||||||
|
"""Test creating an embedding LLM model"""
|
||||||
|
data = {
|
||||||
|
"id": "embedding-test",
|
||||||
|
"name": "Embedding Model",
|
||||||
|
"vendor": "OpenAI",
|
||||||
|
"type": "embedding",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"model_name": "text-embedding-3-small"
|
||||||
|
}
|
||||||
|
response = client.post("/api/llm", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["type"] == "embedding"
|
||||||
267
api/tests/test_tools.py
Normal file
267
api/tests/test_tools.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""Tests for Tools & Autotest API endpoints"""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolsAPI:
|
||||||
|
"""Test cases for Tools endpoints"""
|
||||||
|
|
||||||
|
def test_list_available_tools(self, client):
|
||||||
|
"""Test listing all available tools"""
|
||||||
|
response = client.get("/api/tools/list")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "tools" in data
|
||||||
|
# Check for expected tools
|
||||||
|
tools = data["tools"]
|
||||||
|
assert "search" in tools
|
||||||
|
assert "calculator" in tools
|
||||||
|
assert "weather" in tools
|
||||||
|
|
||||||
|
def test_get_tool_detail(self, client):
|
||||||
|
"""Test getting a specific tool's details"""
|
||||||
|
response = client.get("/api/tools/list/search")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "网络搜索"
|
||||||
|
assert "parameters" in data
|
||||||
|
|
||||||
|
def test_get_tool_detail_not_found(self, client):
|
||||||
|
"""Test getting a non-existent tool"""
|
||||||
|
response = client.get("/api/tools/list/non-existent-tool")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
def test_health_check(self, client):
|
||||||
|
"""Test health check endpoint"""
|
||||||
|
response = client.get("/api/tools/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "healthy"
|
||||||
|
assert "timestamp" in data
|
||||||
|
assert "tools" in data
|
||||||
|
|
||||||
|
|
||||||
|
class TestAutotestAPI:
|
||||||
|
"""Test cases for Autotest endpoints"""
|
||||||
|
|
||||||
|
def test_autotest_no_models(self, client):
|
||||||
|
"""Test autotest without specifying model IDs"""
|
||||||
|
response = client.post("/api/tools/autotest")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "id" in data
|
||||||
|
assert "tests" in data
|
||||||
|
assert "summary" in data
|
||||||
|
# Should have test failures since no models provided
|
||||||
|
assert data["summary"]["total"] > 0
|
||||||
|
|
||||||
|
def test_autotest_with_llm_model(self, client, sample_llm_model_data):
|
||||||
|
"""Test autotest with an LLM model"""
|
||||||
|
# Create an LLM model first
|
||||||
|
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Run autotest
|
||||||
|
response = client.post(f"/api/tools/autotest?llm_model_id={model_id}&test_asr=false")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "tests" in data
|
||||||
|
assert "summary" in data
|
||||||
|
|
||||||
|
def test_autotest_with_asr_model(self, client, sample_asr_model_data):
|
||||||
|
"""Test autotest with an ASR model"""
|
||||||
|
# Create an ASR model first
|
||||||
|
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Run autotest
|
||||||
|
response = client.post(f"/api/tools/autotest?asr_model_id={model_id}&test_llm=false")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "tests" in data
|
||||||
|
assert "summary" in data
|
||||||
|
|
||||||
|
def test_autotest_with_both_models(self, client, sample_llm_model_data, sample_asr_model_data):
|
||||||
|
"""Test autotest with both LLM and ASR models"""
|
||||||
|
# Create models
|
||||||
|
llm_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||||
|
llm_id = llm_response.json()["id"]
|
||||||
|
|
||||||
|
asr_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||||
|
asr_id = asr_response.json()["id"]
|
||||||
|
|
||||||
|
# Run autotest
|
||||||
|
response = client.post(
|
||||||
|
f"/api/tools/autotest?llm_model_id={llm_id}&asr_model_id={asr_id}"
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "tests" in data
|
||||||
|
assert "summary" in data
|
||||||
|
|
||||||
|
@patch('httpx.Client')
|
||||||
|
def test_autotest_llm_model_success(self, mock_client_class, client, sample_llm_model_data):
|
||||||
|
"""Test autotest for a specific LLM model with successful connection"""
|
||||||
|
# Create an LLM model first
|
||||||
|
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Mock the HTTP response for successful connection
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"choices": [{"message": {"content": "OK"}}]
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
mock_response.iter_bytes = MagicMock(return_value=[b'chunk1', b'chunk2'])
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||||
|
mock_client.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
|
with patch('app.routers.tools.httpx.Client', return_value=mock_client):
|
||||||
|
response = client.post(f"/api/tools/autotest/llm/{model_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "tests" in data
|
||||||
|
assert "summary" in data
|
||||||
|
|
||||||
|
@patch('httpx.Client')
|
||||||
|
def test_autotest_asr_model_success(self, mock_client_class, client, sample_asr_model_data):
|
||||||
|
"""Test autotest for a specific ASR model with successful connection"""
|
||||||
|
# Create an ASR model first
|
||||||
|
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Mock the HTTP response for successful connection
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
mock_client.get.return_value = mock_response
|
||||||
|
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||||
|
mock_client.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
|
with patch('app.routers.tools.httpx.Client', return_value=mock_client):
|
||||||
|
response = client.post(f"/api/tools/autotest/asr/{model_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "tests" in data
|
||||||
|
assert "summary" in data
|
||||||
|
|
||||||
|
def test_autotest_llm_model_not_found(self, client):
|
||||||
|
"""Test autotest for a non-existent LLM model"""
|
||||||
|
response = client.post("/api/tools/autotest/llm/non-existent-id")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
# Should have a failure test
|
||||||
|
assert any(not t["passed"] for t in data["tests"])
|
||||||
|
|
||||||
|
def test_autotest_asr_model_not_found(self, client):
|
||||||
|
"""Test autotest for a non-existent ASR model"""
|
||||||
|
response = client.post("/api/tools/autotest/asr/non-existent-id")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
# Should have a failure test
|
||||||
|
assert any(not t["passed"] for t in data["tests"])
|
||||||
|
|
||||||
|
@patch('httpx.Client')
|
||||||
|
def test_test_message_success(self, mock_client_class, client, sample_llm_model_data):
|
||||||
|
"""Test sending a test message to an LLM model"""
|
||||||
|
# Create an LLM model first
|
||||||
|
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||||
|
model_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Mock the HTTP response
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"choices": [{"message": {"content": "Hello! This is a test reply."}}],
|
||||||
|
"usage": {"total_tokens": 10}
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||||
|
mock_client.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
|
with patch('app.routers.tools.httpx.Client', return_value=mock_client):
|
||||||
|
response = client.post(
|
||||||
|
f"/api/tools/test-message?llm_model_id={model_id}",
|
||||||
|
json={"message": "Hello!"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] == True
|
||||||
|
assert "reply" in data
|
||||||
|
|
||||||
|
def test_test_message_model_not_found(self, client):
|
||||||
|
"""Test sending a test message to a non-existent model"""
|
||||||
|
response = client.post(
|
||||||
|
"/api/tools/test-message?llm_model_id=non-existent",
|
||||||
|
json={"message": "Hello!"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
def test_autotest_result_structure(self, client):
|
||||||
|
"""Test that autotest results have the correct structure"""
|
||||||
|
response = client.post("/api/tools/autotest")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Check required fields
|
||||||
|
assert "id" in data
|
||||||
|
assert "started_at" in data
|
||||||
|
assert "duration_ms" in data
|
||||||
|
assert "tests" in data
|
||||||
|
assert "summary" in data
|
||||||
|
|
||||||
|
# Check summary structure
|
||||||
|
assert "passed" in data["summary"]
|
||||||
|
assert "failed" in data["summary"]
|
||||||
|
assert "total" in data["summary"]
|
||||||
|
|
||||||
|
# Check test structure
|
||||||
|
if data["tests"]:
|
||||||
|
test = data["tests"][0]
|
||||||
|
assert "name" in test
|
||||||
|
assert "passed" in test
|
||||||
|
assert "message" in test
|
||||||
|
assert "duration_ms" in test
|
||||||
|
|
||||||
|
def test_tools_have_required_fields(self, client):
|
||||||
|
"""Test that all tools have required fields"""
|
||||||
|
response = client.get("/api/tools/list")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
for tool_id, tool in data["tools"].items():
|
||||||
|
assert "name" in tool
|
||||||
|
assert "description" in tool
|
||||||
|
assert "parameters" in tool
|
||||||
|
|
||||||
|
# Check parameters structure
|
||||||
|
params = tool["parameters"]
|
||||||
|
assert "type" in params
|
||||||
|
assert "properties" in params
|
||||||
|
|
||||||
|
def test_calculator_tool_parameters(self, client):
|
||||||
|
"""Test calculator tool has correct parameters"""
|
||||||
|
response = client.get("/api/tools/list/calculator")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert data["name"] == "计算器"
|
||||||
|
assert "expression" in data["parameters"]["properties"]
|
||||||
|
assert "required" in data["parameters"]
|
||||||
|
assert "expression" in data["parameters"]["required"]
|
||||||
|
|
||||||
|
def test_translate_tool_parameters(self, client):
|
||||||
|
"""Test translate tool has correct parameters"""
|
||||||
|
response = client.get("/api/tools/list/translate")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert data["name"] == "翻译"
|
||||||
|
assert "text" in data["parameters"]["properties"]
|
||||||
|
assert "target_lang" in data["parameters"]["properties"]
|
||||||
132
api/tests/test_voices.py
Normal file
132
api/tests/test_voices.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""Tests for Voice API endpoints"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestVoiceAPI:
|
||||||
|
"""Test cases for Voice endpoints"""
|
||||||
|
|
||||||
|
def test_get_voices_empty(self, client):
|
||||||
|
"""Test getting voices when database is empty"""
|
||||||
|
response = client.get("/api/voices")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "total" in data
|
||||||
|
assert "list" in data
|
||||||
|
|
||||||
|
def test_create_voice(self, client, sample_voice_data):
|
||||||
|
"""Test creating a new voice"""
|
||||||
|
response = client.post("/api/voices", json=sample_voice_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == sample_voice_data["name"]
|
||||||
|
assert data["vendor"] == sample_voice_data["vendor"]
|
||||||
|
assert data["gender"] == sample_voice_data["gender"]
|
||||||
|
assert data["language"] == sample_voice_data["language"]
|
||||||
|
assert "id" in data
|
||||||
|
|
||||||
|
def test_create_voice_minimal(self, client):
|
||||||
|
"""Test creating a voice with minimal data"""
|
||||||
|
data = {
|
||||||
|
"name": "Minimal Voice",
|
||||||
|
"vendor": "Test",
|
||||||
|
"gender": "Male",
|
||||||
|
"language": "en",
|
||||||
|
"description": ""
|
||||||
|
}
|
||||||
|
response = client.post("/api/voices", json=data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_get_voice_by_id(self, client, sample_voice_data):
|
||||||
|
"""Test getting a specific voice by ID"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/voices", json=sample_voice_data)
|
||||||
|
voice_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Get by ID
|
||||||
|
response = client.get(f"/api/voices/{voice_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == voice_id
|
||||||
|
assert data["name"] == sample_voice_data["name"]
|
||||||
|
|
||||||
|
def test_get_voice_not_found(self, client):
|
||||||
|
"""Test getting a non-existent voice"""
|
||||||
|
response = client.get("/api/voices/non-existent-id")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
def test_update_voice(self, client, sample_voice_data):
|
||||||
|
"""Test updating a voice"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/voices", json=sample_voice_data)
|
||||||
|
voice_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Update
|
||||||
|
update_data = {"name": "Updated Voice", "speed": 1.5}
|
||||||
|
response = client.put(f"/api/voices/{voice_id}", json=update_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "Updated Voice"
|
||||||
|
assert data["speed"] == 1.5
|
||||||
|
|
||||||
|
def test_delete_voice(self, client, sample_voice_data):
|
||||||
|
"""Test deleting a voice"""
|
||||||
|
# Create first
|
||||||
|
create_response = client.post("/api/voices", json=sample_voice_data)
|
||||||
|
voice_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
response = client.delete(f"/api/voices/{voice_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Verify deleted
|
||||||
|
get_response = client.get(f"/api/voices/{voice_id}")
|
||||||
|
assert get_response.status_code == 404
|
||||||
|
|
||||||
|
def test_list_voices_with_pagination(self, client, sample_voice_data):
|
||||||
|
"""Test listing voices with pagination"""
|
||||||
|
# Create multiple voices
|
||||||
|
for i in range(3):
|
||||||
|
data = sample_voice_data.copy()
|
||||||
|
data["name"] = f"Voice {i}"
|
||||||
|
client.post("/api/voices", json=data)
|
||||||
|
|
||||||
|
# Test pagination
|
||||||
|
response = client.get("/api/voices?page=1&limit=2")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 3
|
||||||
|
assert len(data["list"]) == 2
|
||||||
|
|
||||||
|
def test_filter_voices_by_vendor(self, client, sample_voice_data):
|
||||||
|
"""Test filtering voices by vendor"""
|
||||||
|
# Create voice with specific vendor
|
||||||
|
sample_voice_data["vendor"] = "FilterTestVendor"
|
||||||
|
client.post("/api/voices", json=sample_voice_data)
|
||||||
|
|
||||||
|
response = client.get("/api/voices?vendor=FilterTestVendor")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
for voice in data["list"]:
|
||||||
|
assert voice["vendor"] == "FilterTestVendor"
|
||||||
|
|
||||||
|
def test_filter_voices_by_language(self, client, sample_voice_data):
|
||||||
|
"""Test filtering voices by language"""
|
||||||
|
sample_voice_data["language"] = "en"
|
||||||
|
client.post("/api/voices", json=sample_voice_data)
|
||||||
|
|
||||||
|
response = client.get("/api/voices?language=en")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
for voice in data["list"]:
|
||||||
|
assert voice["language"] == "en"
|
||||||
|
|
||||||
|
def test_filter_voices_by_gender(self, client, sample_voice_data):
|
||||||
|
"""Test filtering voices by gender"""
|
||||||
|
sample_voice_data["gender"] = "Female"
|
||||||
|
client.post("/api/voices", json=sample_voice_data)
|
||||||
|
|
||||||
|
response = client.get("/api/voices?gender=Female")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
for voice in data["list"]:
|
||||||
|
assert voice["gender"] == "Female"
|
||||||
Reference in New Issue
Block a user