Files
AI-VideoAssistant/api/tests/test_llm.py
2026-02-08 15:52:16 +08:00

247 lines
9.5 KiB
Python

"""Tests for LLM Model API endpoints"""
import pytest
from unittest.mock import patch, MagicMock
class TestLLMModelAPI:
    """Test cases for the LLM Model endpoints.

    Covers CRUD (create/read/update/delete), list pagination and filtering,
    and the connection-test endpoint (``POST /api/llm/{id}/test``) with the
    outbound HTTP client mocked.

    Relies on two pytest fixtures defined elsewhere in the test suite:
    ``client`` (an HTTP test client for the app) and ``sample_llm_model_data``
    (a dict payload for creating a model).
    """

    def test_get_llm_models_empty(self, client):
        """Test getting LLM models when database is empty."""
        response = client.get("/api/llm")
        assert response.status_code == 200
        data = response.json()
        assert "total" in data
        assert "list" in data
        assert data["total"] == 0

    def test_create_llm_model(self, client, sample_llm_model_data):
        """Test creating a new LLM model."""
        response = client.post("/api/llm", json=sample_llm_model_data)
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == sample_llm_model_data["name"]
        assert data["vendor"] == sample_llm_model_data["vendor"]
        assert data["type"] == sample_llm_model_data["type"]
        assert data["base_url"] == sample_llm_model_data["base_url"]
        # The server assigns an id when the payload does not supply one.
        assert "id" in data

    def test_create_llm_model_minimal(self, client):
        """Test creating an LLM model with minimal required data."""
        data = {
            "name": "Minimal LLM",
            "vendor": "Test",
            "type": "text",
            "base_url": "https://api.test.com",
            "api_key": "test-key"
        }
        response = client.post("/api/llm", json=data)
        assert response.status_code == 200
        assert response.json()["name"] == "Minimal LLM"

    def test_get_llm_model_by_id(self, client, sample_llm_model_data):
        """Test getting a specific LLM model by ID."""
        # Create first
        create_response = client.post("/api/llm", json=sample_llm_model_data)
        model_id = create_response.json()["id"]
        # Get by ID
        response = client.get(f"/api/llm/{model_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == model_id
        assert data["name"] == sample_llm_model_data["name"]

    def test_get_llm_model_not_found(self, client):
        """Test getting a non-existent LLM model."""
        response = client.get("/api/llm/non-existent-id")
        assert response.status_code == 404

    def test_update_llm_model(self, client, sample_llm_model_data):
        """Test updating an LLM model."""
        # Create first
        create_response = client.post("/api/llm", json=sample_llm_model_data)
        model_id = create_response.json()["id"]
        # Update
        update_data = {
            "name": "Updated LLM Model",
            "temperature": 0.5,
            "context_length": 8192
        }
        response = client.put(f"/api/llm/{model_id}", json=update_data)
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Updated LLM Model"
        assert data["temperature"] == 0.5
        assert data["context_length"] == 8192

    def test_delete_llm_model(self, client, sample_llm_model_data):
        """Test deleting an LLM model."""
        # Create first
        create_response = client.post("/api/llm", json=sample_llm_model_data)
        model_id = create_response.json()["id"]
        # Delete
        response = client.delete(f"/api/llm/{model_id}")
        assert response.status_code == 200
        # Verify deleted
        get_response = client.get(f"/api/llm/{model_id}")
        assert get_response.status_code == 404

    def test_list_llm_models_with_pagination(self, client, sample_llm_model_data):
        """Test listing LLM models with pagination."""
        # Create multiple models
        for i in range(3):
            data = sample_llm_model_data.copy()
            data["id"] = f"test-llm-{i}"
            data["name"] = f"LLM Model {i}"
            client.post("/api/llm", json=data)
        # Test pagination: total reflects all rows, list is capped at the limit
        response = client.get("/api/llm?page=1&limit=2")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 3
        assert len(data["list"]) == 2

    def test_filter_llm_models_by_type(self, client, sample_llm_model_data):
        """Test filtering LLM models by type."""
        # Create models with different types
        for model_type in ["text", "embedding", "rerank"]:
            data = sample_llm_model_data.copy()
            data["id"] = f"test-llm-{model_type}"
            data["name"] = f"LLM {model_type}"
            data["type"] = model_type
            client.post("/api/llm", json=data)
        # Filter by type
        response = client.get("/api/llm?model_type=text")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] >= 1
        for model in data["list"]:
            assert model["type"] == "text"

    def test_filter_llm_models_by_enabled(self, client, sample_llm_model_data):
        """Test filtering LLM models by enabled status."""
        # Create enabled and disabled models
        data = sample_llm_model_data.copy()
        data["id"] = "test-llm-enabled"
        data["name"] = "Enabled LLM"
        data["enabled"] = True
        client.post("/api/llm", json=data)
        data["id"] = "test-llm-disabled"
        data["name"] = "Disabled LLM"
        data["enabled"] = False
        client.post("/api/llm", json=data)
        # Filter by enabled
        response = client.get("/api/llm?enabled=true")
        assert response.status_code == 200
        data = response.json()
        for model in data["list"]:
            assert model["enabled"] is True

    def test_create_llm_model_with_all_fields(self, client):
        """Test creating an LLM model with all fields."""
        data = {
            "id": "full-llm",
            "name": "Full LLM Model",
            "vendor": "OpenAI",
            "type": "text",
            "base_url": "https://api.openai.com/v1",
            "api_key": "sk-test",
            "model_name": "gpt-4",
            "temperature": 0.8,
            "context_length": 16384,
            "enabled": True
        }
        response = client.post("/api/llm", json=data)
        assert response.status_code == 200
        result = response.json()
        assert result["name"] == "Full LLM Model"
        assert result["temperature"] == 0.8
        assert result["context_length"] == 16384

    def test_test_llm_model_success(self, client, sample_llm_model_data):
        """Test testing an LLM model with successful connection."""
        # Create model first
        create_response = client.post("/api/llm", json=sample_llm_model_data)
        model_id = create_response.json()["id"]
        # Mock the HTTP response; the mock also acts as its own context
        # manager so it works inside a `with httpx.Client() as ...` block.
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "OK"}}]
        }
        mock_response.raise_for_status = MagicMock()
        mock_client.post.return_value = mock_response
        mock_client.__enter__ = MagicMock(return_value=mock_client)
        mock_client.__exit__ = MagicMock(return_value=False)
        # Patch at the import site used by the router, not the httpx module.
        with patch('app.routers.llm.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/llm/{model_id}/test")
        # The endpoint reports outcome in the body, not via the status code.
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True

    def test_test_llm_model_failure(self, client, sample_llm_model_data):
        """Test testing an LLM model with failed connection."""
        # Create model first
        create_response = client.post("/api/llm", json=sample_llm_model_data)
        model_id = create_response.json()["id"]
        # Mock HTTP error: raise_for_status raises, simulating a 401 upstream.
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = 401
        mock_response.text = "Unauthorized"
        mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized"))
        mock_client.post.return_value = mock_response
        mock_client.__enter__ = MagicMock(return_value=mock_client)
        mock_client.__exit__ = MagicMock(return_value=False)
        with patch('app.routers.llm.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/llm/{model_id}/test")
        # Endpoint still returns 200; the failure is flagged in the payload.
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is False

    def test_different_llm_vendors(self, client):
        """Test creating LLM models with different vendors."""
        vendors = ["OpenAI", "SiliconFlow", "ZhipuAI", "Anthropic"]
        for vendor in vendors:
            data = {
                "id": f"test-{vendor.lower()}",
                "name": f"Test {vendor}",
                "vendor": vendor,
                "type": "text",
                "base_url": f"https://api.{vendor.lower()}.com/v1",
                "api_key": "test-key"
            }
            response = client.post("/api/llm", json=data)
            assert response.status_code == 200
            assert response.json()["vendor"] == vendor

    def test_embedding_llm_model(self, client):
        """Test creating an embedding LLM model."""
        data = {
            "id": "embedding-test",
            "name": "Embedding Model",
            "vendor": "OpenAI",
            "type": "embedding",
            "base_url": "https://api.openai.com/v1",
            "api_key": "test-key",
            "model_name": "text-embedding-3-small"
        }
        response = client.post("/api/llm", json=data)
        assert response.status_code == 200
        assert response.json()["type"] == "embedding"