Update backend api
This commit is contained in:
@@ -100,3 +100,38 @@ def sample_call_record_data():
|
||||
"assistant_id": None,
|
||||
"source": "debug"
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_llm_model_data():
|
||||
"""Sample LLM model data for testing"""
|
||||
return {
|
||||
"id": "test-llm-001",
|
||||
"name": "Test LLM Model",
|
||||
"vendor": "TestVendor",
|
||||
"type": "text",
|
||||
"base_url": "https://api.test.com/v1",
|
||||
"api_key": "test-api-key",
|
||||
"model_name": "test-model",
|
||||
"temperature": 0.7,
|
||||
"context_length": 4096,
|
||||
"enabled": True
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_asr_model_data():
|
||||
"""Sample ASR model data for testing"""
|
||||
return {
|
||||
"id": "test-asr-001",
|
||||
"name": "Test ASR Model",
|
||||
"vendor": "TestVendor",
|
||||
"language": "zh",
|
||||
"base_url": "https://api.test.com/v1",
|
||||
"api_key": "test-api-key",
|
||||
"model_name": "paraformer-v2",
|
||||
"hotwords": ["测试", "语音"],
|
||||
"enable_punctuation": True,
|
||||
"enable_normalization": True,
|
||||
"enabled": True
|
||||
}
|
||||
|
||||
289
api/tests/test_asr.py
Normal file
289
api/tests/test_asr.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""Tests for ASR Model API endpoints"""
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestASRModelAPI:
|
||||
"""Test cases for ASR Model endpoints"""
|
||||
|
||||
def test_get_asr_models_empty(self, client):
|
||||
"""Test getting ASR models when database is empty"""
|
||||
response = client.get("/api/asr")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total" in data
|
||||
assert "list" in data
|
||||
assert data["total"] == 0
|
||||
|
||||
def test_create_asr_model(self, client, sample_asr_model_data):
|
||||
"""Test creating a new ASR model"""
|
||||
response = client.post("/api/asr", json=sample_asr_model_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == sample_asr_model_data["name"]
|
||||
assert data["vendor"] == sample_asr_model_data["vendor"]
|
||||
assert data["language"] == sample_asr_model_data["language"]
|
||||
assert "id" in data
|
||||
|
||||
def test_create_asr_model_minimal(self, client):
|
||||
"""Test creating an ASR model with minimal required data"""
|
||||
data = {
|
||||
"name": "Minimal ASR",
|
||||
"vendor": "Test",
|
||||
"language": "zh",
|
||||
"base_url": "https://api.test.com",
|
||||
"api_key": "test-key"
|
||||
}
|
||||
response = client.post("/api/asr", json=data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Minimal ASR"
|
||||
|
||||
def test_get_asr_model_by_id(self, client, sample_asr_model_data):
|
||||
"""Test getting a specific ASR model by ID"""
|
||||
# Create first
|
||||
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Get by ID
|
||||
response = client.get(f"/api/asr/{model_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == model_id
|
||||
assert data["name"] == sample_asr_model_data["name"]
|
||||
|
||||
def test_get_asr_model_not_found(self, client):
|
||||
"""Test getting a non-existent ASR model"""
|
||||
response = client.get("/api/asr/non-existent-id")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_update_asr_model(self, client, sample_asr_model_data):
|
||||
"""Test updating an ASR model"""
|
||||
# Create first
|
||||
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Update
|
||||
update_data = {
|
||||
"name": "Updated ASR Model",
|
||||
"language": "en",
|
||||
"enable_punctuation": False
|
||||
}
|
||||
response = client.put(f"/api/asr/{model_id}", json=update_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Updated ASR Model"
|
||||
assert data["language"] == "en"
|
||||
assert data["enable_punctuation"] == False
|
||||
|
||||
def test_delete_asr_model(self, client, sample_asr_model_data):
|
||||
"""Test deleting an ASR model"""
|
||||
# Create first
|
||||
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Delete
|
||||
response = client.delete(f"/api/asr/{model_id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify deleted
|
||||
get_response = client.get(f"/api/asr/{model_id}")
|
||||
assert get_response.status_code == 404
|
||||
|
||||
def test_list_asr_models_with_pagination(self, client, sample_asr_model_data):
|
||||
"""Test listing ASR models with pagination"""
|
||||
# Create multiple models
|
||||
for i in range(3):
|
||||
data = sample_asr_model_data.copy()
|
||||
data["id"] = f"test-asr-{i}"
|
||||
data["name"] = f"ASR Model {i}"
|
||||
client.post("/api/asr", json=data)
|
||||
|
||||
# Test pagination
|
||||
response = client.get("/api/asr?page=1&limit=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 3
|
||||
assert len(data["list"]) == 2
|
||||
|
||||
def test_filter_asr_models_by_language(self, client, sample_asr_model_data):
|
||||
"""Test filtering ASR models by language"""
|
||||
# Create models with different languages
|
||||
for i, lang in enumerate(["zh", "en", "Multi-lingual"]):
|
||||
data = sample_asr_model_data.copy()
|
||||
data["id"] = f"test-asr-{lang}"
|
||||
data["name"] = f"ASR {lang}"
|
||||
data["language"] = lang
|
||||
client.post("/api/asr", json=data)
|
||||
|
||||
# Filter by language
|
||||
response = client.get("/api/asr?language=zh")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
for model in data["list"]:
|
||||
assert model["language"] == "zh"
|
||||
|
||||
def test_filter_asr_models_by_enabled(self, client, sample_asr_model_data):
|
||||
"""Test filtering ASR models by enabled status"""
|
||||
# Create enabled and disabled models
|
||||
data = sample_asr_model_data.copy()
|
||||
data["id"] = "test-asr-enabled"
|
||||
data["name"] = "Enabled ASR"
|
||||
data["enabled"] = True
|
||||
client.post("/api/asr", json=data)
|
||||
|
||||
data["id"] = "test-asr-disabled"
|
||||
data["name"] = "Disabled ASR"
|
||||
data["enabled"] = False
|
||||
client.post("/api/asr", json=data)
|
||||
|
||||
# Filter by enabled
|
||||
response = client.get("/api/asr?enabled=true")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for model in data["list"]:
|
||||
assert model["enabled"] == True
|
||||
|
||||
def test_create_asr_model_with_hotwords(self, client):
|
||||
"""Test creating an ASR model with hotwords"""
|
||||
data = {
|
||||
"id": "asr-hotwords",
|
||||
"name": "ASR with Hotwords",
|
||||
"vendor": "SiliconFlow",
|
||||
"language": "zh",
|
||||
"base_url": "https://api.siliconflow.cn/v1",
|
||||
"api_key": "test-key",
|
||||
"model_name": "paraformer-v2",
|
||||
"hotwords": ["你好", "查询", "帮助"],
|
||||
"enable_punctuation": True,
|
||||
"enable_normalization": True
|
||||
}
|
||||
response = client.post("/api/asr", json=data)
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
assert result["hotwords"] == ["你好", "查询", "帮助"]
|
||||
|
||||
def test_create_asr_model_with_all_fields(self, client):
|
||||
"""Test creating an ASR model with all fields"""
|
||||
data = {
|
||||
"id": "full-asr",
|
||||
"name": "Full ASR Model",
|
||||
"vendor": "SiliconFlow",
|
||||
"language": "zh",
|
||||
"base_url": "https://api.siliconflow.cn/v1",
|
||||
"api_key": "sk-test",
|
||||
"model_name": "paraformer-v2",
|
||||
"hotwords": ["测试"],
|
||||
"enable_punctuation": True,
|
||||
"enable_normalization": True,
|
||||
"enabled": True
|
||||
}
|
||||
response = client.post("/api/asr", json=data)
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
assert result["name"] == "Full ASR Model"
|
||||
assert result["enable_punctuation"] == True
|
||||
assert result["enable_normalization"] == True
|
||||
|
||||
@patch('httpx.Client')
|
||||
def test_test_asr_model_siliconflow(self, mock_client_class, client, sample_asr_model_data):
|
||||
"""Test testing an ASR model with SiliconFlow vendor"""
|
||||
# Create model first
|
||||
sample_asr_model_data["vendor"] = "SiliconFlow"
|
||||
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Mock the HTTP response
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"results": [{"transcript": "测试文本", "language": "zh"}]
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch('app.routers.asr.httpx.Client', return_value=mock_client):
|
||||
response = client.post(f"/api/asr/{model_id}/test")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] == True
|
||||
|
||||
@patch('httpx.Client')
|
||||
def test_test_asr_model_openai(self, mock_client_class, client, sample_asr_model_data):
|
||||
"""Test testing an ASR model with OpenAI vendor"""
|
||||
# Create model with OpenAI vendor
|
||||
sample_asr_model_data["vendor"] = "OpenAI"
|
||||
sample_asr_model_data["id"] = "test-asr-openai"
|
||||
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Mock the HTTP response
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"text": "Test transcript"}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch('app.routers.asr.httpx.Client', return_value=mock_client):
|
||||
response = client.post(f"/api/asr/{model_id}/test")
|
||||
assert response.status_code == 200
|
||||
|
||||
@patch('httpx.Client')
|
||||
def test_test_asr_model_failure(self, mock_client_class, client, sample_asr_model_data):
|
||||
"""Test testing an ASR model with failed connection"""
|
||||
# Create model first
|
||||
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Mock HTTP error
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Unauthorized"
|
||||
mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized"))
|
||||
mock_client.get.return_value = mock_response
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch('app.routers.asr.httpx.Client', return_value=mock_client):
|
||||
response = client.post(f"/api/asr/{model_id}/test")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] == False
|
||||
|
||||
def test_different_asr_languages(self, client):
|
||||
"""Test creating ASR models with different languages"""
|
||||
for lang in ["zh", "en", "Multi-lingual"]:
|
||||
data = {
|
||||
"id": f"asr-lang-{lang}",
|
||||
"name": f"ASR {lang}",
|
||||
"vendor": "SiliconFlow",
|
||||
"language": lang,
|
||||
"base_url": "https://api.siliconflow.cn/v1",
|
||||
"api_key": "test-key"
|
||||
}
|
||||
response = client.post("/api/asr", json=data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["language"] == lang
|
||||
|
||||
def test_different_asr_vendors(self, client):
|
||||
"""Test creating ASR models with different vendors"""
|
||||
vendors = ["SiliconFlow", "OpenAI", "Azure"]
|
||||
for vendor in vendors:
|
||||
data = {
|
||||
"id": f"asr-vendor-{vendor.lower()}",
|
||||
"name": f"ASR {vendor}",
|
||||
"vendor": vendor,
|
||||
"language": "zh",
|
||||
"base_url": f"https://api.{vendor.lower()}.com/v1",
|
||||
"api_key": "test-key"
|
||||
}
|
||||
response = client.post("/api/asr", json=data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["vendor"] == vendor
|
||||
246
api/tests/test_llm.py
Normal file
246
api/tests/test_llm.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""Tests for LLM Model API endpoints"""
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestLLMModelAPI:
|
||||
"""Test cases for LLM Model endpoints"""
|
||||
|
||||
def test_get_llm_models_empty(self, client):
|
||||
"""Test getting LLM models when database is empty"""
|
||||
response = client.get("/api/llm")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total" in data
|
||||
assert "list" in data
|
||||
assert data["total"] == 0
|
||||
|
||||
def test_create_llm_model(self, client, sample_llm_model_data):
|
||||
"""Test creating a new LLM model"""
|
||||
response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == sample_llm_model_data["name"]
|
||||
assert data["vendor"] == sample_llm_model_data["vendor"]
|
||||
assert data["type"] == sample_llm_model_data["type"]
|
||||
assert data["base_url"] == sample_llm_model_data["base_url"]
|
||||
assert "id" in data
|
||||
|
||||
def test_create_llm_model_minimal(self, client):
|
||||
"""Test creating an LLM model with minimal required data"""
|
||||
data = {
|
||||
"name": "Minimal LLM",
|
||||
"vendor": "Test",
|
||||
"type": "text",
|
||||
"base_url": "https://api.test.com",
|
||||
"api_key": "test-key"
|
||||
}
|
||||
response = client.post("/api/llm", json=data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Minimal LLM"
|
||||
|
||||
def test_get_llm_model_by_id(self, client, sample_llm_model_data):
|
||||
"""Test getting a specific LLM model by ID"""
|
||||
# Create first
|
||||
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Get by ID
|
||||
response = client.get(f"/api/llm/{model_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == model_id
|
||||
assert data["name"] == sample_llm_model_data["name"]
|
||||
|
||||
def test_get_llm_model_not_found(self, client):
|
||||
"""Test getting a non-existent LLM model"""
|
||||
response = client.get("/api/llm/non-existent-id")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_update_llm_model(self, client, sample_llm_model_data):
|
||||
"""Test updating an LLM model"""
|
||||
# Create first
|
||||
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Update
|
||||
update_data = {
|
||||
"name": "Updated LLM Model",
|
||||
"temperature": 0.5,
|
||||
"context_length": 8192
|
||||
}
|
||||
response = client.put(f"/api/llm/{model_id}", json=update_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Updated LLM Model"
|
||||
assert data["temperature"] == 0.5
|
||||
assert data["context_length"] == 8192
|
||||
|
||||
def test_delete_llm_model(self, client, sample_llm_model_data):
|
||||
"""Test deleting an LLM model"""
|
||||
# Create first
|
||||
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Delete
|
||||
response = client.delete(f"/api/llm/{model_id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify deleted
|
||||
get_response = client.get(f"/api/llm/{model_id}")
|
||||
assert get_response.status_code == 404
|
||||
|
||||
def test_list_llm_models_with_pagination(self, client, sample_llm_model_data):
|
||||
"""Test listing LLM models with pagination"""
|
||||
# Create multiple models
|
||||
for i in range(3):
|
||||
data = sample_llm_model_data.copy()
|
||||
data["id"] = f"test-llm-{i}"
|
||||
data["name"] = f"LLM Model {i}"
|
||||
client.post("/api/llm", json=data)
|
||||
|
||||
# Test pagination
|
||||
response = client.get("/api/llm?page=1&limit=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 3
|
||||
assert len(data["list"]) == 2
|
||||
|
||||
def test_filter_llm_models_by_type(self, client, sample_llm_model_data):
|
||||
"""Test filtering LLM models by type"""
|
||||
# Create models with different types
|
||||
for i, model_type in enumerate(["text", "embedding", "rerank"]):
|
||||
data = sample_llm_model_data.copy()
|
||||
data["id"] = f"test-llm-{model_type}"
|
||||
data["name"] = f"LLM {model_type}"
|
||||
data["type"] = model_type
|
||||
client.post("/api/llm", json=data)
|
||||
|
||||
# Filter by type
|
||||
response = client.get("/api/llm?model_type=text")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
for model in data["list"]:
|
||||
assert model["type"] == "text"
|
||||
|
||||
def test_filter_llm_models_by_enabled(self, client, sample_llm_model_data):
|
||||
"""Test filtering LLM models by enabled status"""
|
||||
# Create enabled and disabled models
|
||||
data = sample_llm_model_data.copy()
|
||||
data["id"] = "test-llm-enabled"
|
||||
data["name"] = "Enabled LLM"
|
||||
data["enabled"] = True
|
||||
client.post("/api/llm", json=data)
|
||||
|
||||
data["id"] = "test-llm-disabled"
|
||||
data["name"] = "Disabled LLM"
|
||||
data["enabled"] = False
|
||||
client.post("/api/llm", json=data)
|
||||
|
||||
# Filter by enabled
|
||||
response = client.get("/api/llm?enabled=true")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for model in data["list"]:
|
||||
assert model["enabled"] == True
|
||||
|
||||
def test_create_llm_model_with_all_fields(self, client):
|
||||
"""Test creating an LLM model with all fields"""
|
||||
data = {
|
||||
"id": "full-llm",
|
||||
"name": "Full LLM Model",
|
||||
"vendor": "OpenAI",
|
||||
"type": "text",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "sk-test",
|
||||
"model_name": "gpt-4",
|
||||
"temperature": 0.8,
|
||||
"context_length": 16384,
|
||||
"enabled": True
|
||||
}
|
||||
response = client.post("/api/llm", json=data)
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
assert result["name"] == "Full LLM Model"
|
||||
assert result["temperature"] == 0.8
|
||||
assert result["context_length"] == 16384
|
||||
|
||||
@patch('httpx.Client')
|
||||
def test_test_llm_model_success(self, mock_client_class, client, sample_llm_model_data):
|
||||
"""Test testing an LLM model with successful connection"""
|
||||
# Create model first
|
||||
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Mock the HTTP response
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "OK"}}]
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch('app.routers.llm.httpx.Client', return_value=mock_client):
|
||||
response = client.post(f"/api/llm/{model_id}/test")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] == True
|
||||
|
||||
@patch('httpx.Client')
|
||||
def test_test_llm_model_failure(self, mock_client_class, client, sample_llm_model_data):
|
||||
"""Test testing an LLM model with failed connection"""
|
||||
# Create model first
|
||||
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Mock HTTP error
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Unauthorized"
|
||||
mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized"))
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch('app.routers.llm.httpx.Client', return_value=mock_client):
|
||||
response = client.post(f"/api/llm/{model_id}/test")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] == False
|
||||
|
||||
def test_different_llm_vendors(self, client):
|
||||
"""Test creating LLM models with different vendors"""
|
||||
vendors = ["OpenAI", "SiliconFlow", "ZhipuAI", "Anthropic"]
|
||||
for vendor in vendors:
|
||||
data = {
|
||||
"id": f"test-{vendor.lower()}",
|
||||
"name": f"Test {vendor}",
|
||||
"vendor": vendor,
|
||||
"type": "text",
|
||||
"base_url": f"https://api.{vendor.lower()}.com/v1",
|
||||
"api_key": "test-key"
|
||||
}
|
||||
response = client.post("/api/llm", json=data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["vendor"] == vendor
|
||||
|
||||
def test_embedding_llm_model(self, client):
|
||||
"""Test creating an embedding LLM model"""
|
||||
data = {
|
||||
"id": "embedding-test",
|
||||
"name": "Embedding Model",
|
||||
"vendor": "OpenAI",
|
||||
"type": "embedding",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "test-key",
|
||||
"model_name": "text-embedding-3-small"
|
||||
}
|
||||
response = client.post("/api/llm", json=data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["type"] == "embedding"
|
||||
267
api/tests/test_tools.py
Normal file
267
api/tests/test_tools.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Tests for Tools & Autotest API endpoints"""
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestToolsAPI:
|
||||
"""Test cases for Tools endpoints"""
|
||||
|
||||
def test_list_available_tools(self, client):
|
||||
"""Test listing all available tools"""
|
||||
response = client.get("/api/tools/list")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "tools" in data
|
||||
# Check for expected tools
|
||||
tools = data["tools"]
|
||||
assert "search" in tools
|
||||
assert "calculator" in tools
|
||||
assert "weather" in tools
|
||||
|
||||
def test_get_tool_detail(self, client):
|
||||
"""Test getting a specific tool's details"""
|
||||
response = client.get("/api/tools/list/search")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "网络搜索"
|
||||
assert "parameters" in data
|
||||
|
||||
def test_get_tool_detail_not_found(self, client):
|
||||
"""Test getting a non-existent tool"""
|
||||
response = client.get("/api/tools/list/non-existent-tool")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_health_check(self, client):
|
||||
"""Test health check endpoint"""
|
||||
response = client.get("/api/tools/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "timestamp" in data
|
||||
assert "tools" in data
|
||||
|
||||
|
||||
class TestAutotestAPI:
|
||||
"""Test cases for Autotest endpoints"""
|
||||
|
||||
def test_autotest_no_models(self, client):
|
||||
"""Test autotest without specifying model IDs"""
|
||||
response = client.post("/api/tools/autotest")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
assert "tests" in data
|
||||
assert "summary" in data
|
||||
# Should have test failures since no models provided
|
||||
assert data["summary"]["total"] > 0
|
||||
|
||||
def test_autotest_with_llm_model(self, client, sample_llm_model_data):
|
||||
"""Test autotest with an LLM model"""
|
||||
# Create an LLM model first
|
||||
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Run autotest
|
||||
response = client.post(f"/api/tools/autotest?llm_model_id={model_id}&test_asr=false")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "tests" in data
|
||||
assert "summary" in data
|
||||
|
||||
def test_autotest_with_asr_model(self, client, sample_asr_model_data):
|
||||
"""Test autotest with an ASR model"""
|
||||
# Create an ASR model first
|
||||
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Run autotest
|
||||
response = client.post(f"/api/tools/autotest?asr_model_id={model_id}&test_llm=false")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "tests" in data
|
||||
assert "summary" in data
|
||||
|
||||
def test_autotest_with_both_models(self, client, sample_llm_model_data, sample_asr_model_data):
|
||||
"""Test autotest with both LLM and ASR models"""
|
||||
# Create models
|
||||
llm_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
llm_id = llm_response.json()["id"]
|
||||
|
||||
asr_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||
asr_id = asr_response.json()["id"]
|
||||
|
||||
# Run autotest
|
||||
response = client.post(
|
||||
f"/api/tools/autotest?llm_model_id={llm_id}&asr_model_id={asr_id}"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "tests" in data
|
||||
assert "summary" in data
|
||||
|
||||
@patch('httpx.Client')
|
||||
def test_autotest_llm_model_success(self, mock_client_class, client, sample_llm_model_data):
|
||||
"""Test autotest for a specific LLM model with successful connection"""
|
||||
# Create an LLM model first
|
||||
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Mock the HTTP response for successful connection
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "OK"}}]
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.iter_bytes = MagicMock(return_value=[b'chunk1', b'chunk2'])
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch('app.routers.tools.httpx.Client', return_value=mock_client):
|
||||
response = client.post(f"/api/tools/autotest/llm/{model_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "tests" in data
|
||||
assert "summary" in data
|
||||
|
||||
@patch('httpx.Client')
|
||||
def test_autotest_asr_model_success(self, mock_client_class, client, sample_asr_model_data):
|
||||
"""Test autotest for a specific ASR model with successful connection"""
|
||||
# Create an ASR model first
|
||||
create_response = client.post("/api/asr", json=sample_asr_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Mock the HTTP response for successful connection
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch('app.routers.tools.httpx.Client', return_value=mock_client):
|
||||
response = client.post(f"/api/tools/autotest/asr/{model_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "tests" in data
|
||||
assert "summary" in data
|
||||
|
||||
def test_autotest_llm_model_not_found(self, client):
|
||||
"""Test autotest for a non-existent LLM model"""
|
||||
response = client.post("/api/tools/autotest/llm/non-existent-id")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should have a failure test
|
||||
assert any(not t["passed"] for t in data["tests"])
|
||||
|
||||
def test_autotest_asr_model_not_found(self, client):
|
||||
"""Test autotest for a non-existent ASR model"""
|
||||
response = client.post("/api/tools/autotest/asr/non-existent-id")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should have a failure test
|
||||
assert any(not t["passed"] for t in data["tests"])
|
||||
|
||||
@patch('httpx.Client')
|
||||
def test_test_message_success(self, mock_client_class, client, sample_llm_model_data):
|
||||
"""Test sending a test message to an LLM model"""
|
||||
# Create an LLM model first
|
||||
create_response = client.post("/api/llm", json=sample_llm_model_data)
|
||||
model_id = create_response.json()["id"]
|
||||
|
||||
# Mock the HTTP response
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "Hello! This is a test reply."}}],
|
||||
"usage": {"total_tokens": 10}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch('app.routers.tools.httpx.Client', return_value=mock_client):
|
||||
response = client.post(
|
||||
f"/api/tools/test-message?llm_model_id={model_id}",
|
||||
json={"message": "Hello!"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] == True
|
||||
assert "reply" in data
|
||||
|
||||
def test_test_message_model_not_found(self, client):
|
||||
"""Test sending a test message to a non-existent model"""
|
||||
response = client.post(
|
||||
"/api/tools/test-message?llm_model_id=non-existent",
|
||||
json={"message": "Hello!"}
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_autotest_result_structure(self, client):
|
||||
"""Test that autotest results have the correct structure"""
|
||||
response = client.post("/api/tools/autotest")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Check required fields
|
||||
assert "id" in data
|
||||
assert "started_at" in data
|
||||
assert "duration_ms" in data
|
||||
assert "tests" in data
|
||||
assert "summary" in data
|
||||
|
||||
# Check summary structure
|
||||
assert "passed" in data["summary"]
|
||||
assert "failed" in data["summary"]
|
||||
assert "total" in data["summary"]
|
||||
|
||||
# Check test structure
|
||||
if data["tests"]:
|
||||
test = data["tests"][0]
|
||||
assert "name" in test
|
||||
assert "passed" in test
|
||||
assert "message" in test
|
||||
assert "duration_ms" in test
|
||||
|
||||
def test_tools_have_required_fields(self, client):
|
||||
"""Test that all tools have required fields"""
|
||||
response = client.get("/api/tools/list")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
for tool_id, tool in data["tools"].items():
|
||||
assert "name" in tool
|
||||
assert "description" in tool
|
||||
assert "parameters" in tool
|
||||
|
||||
# Check parameters structure
|
||||
params = tool["parameters"]
|
||||
assert "type" in params
|
||||
assert "properties" in params
|
||||
|
||||
def test_calculator_tool_parameters(self, client):
|
||||
"""Test calculator tool has correct parameters"""
|
||||
response = client.get("/api/tools/list/calculator")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["name"] == "计算器"
|
||||
assert "expression" in data["parameters"]["properties"]
|
||||
assert "required" in data["parameters"]
|
||||
assert "expression" in data["parameters"]["required"]
|
||||
|
||||
def test_translate_tool_parameters(self, client):
|
||||
"""Test translate tool has correct parameters"""
|
||||
response = client.get("/api/tools/list/translate")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["name"] == "翻译"
|
||||
assert "text" in data["parameters"]["properties"]
|
||||
assert "target_lang" in data["parameters"]["properties"]
|
||||
Reference in New Issue
Block a user