# File-view metadata: 303 lines, 12 KiB, Python
"""Tests for LLM Model API endpoints"""
|
|
import pytest
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
|
|
class TestLLMModelAPI:
    """Test cases for the LLM model endpoints under ``/api/llm``.

    The ``client`` and ``sample_llm_model_data`` fixtures are not defined in
    this class (presumably they come from ``conftest.py`` -- TODO confirm).
    Several tests assert exact totals (``0`` on an empty table, ``3`` after
    three creates), so they assume every test starts with no LLM models.
    """

    @staticmethod
    def _create_model(client, payload):
        """Create an LLM model via ``POST /api/llm`` and return its ``id``.

        Checking the status here makes a failing setup step report the real
        HTTP error. Otherwise it surfaces later as a ``KeyError`` on ``["id"]``,
        or lets a filter test pass vacuously over an empty list.
        """
        response = client.post("/api/llm", json=payload)
        assert response.status_code == 200, response.text
        return response.json()["id"]

    def test_get_llm_models_empty(self, client):
        """Test getting LLM models when database is empty"""
        response = client.get("/api/llm")
        assert response.status_code == 200
        data = response.json()
        assert "total" in data
        assert "list" in data
        assert data["total"] == 0
        assert data["list"] == []

    def test_create_llm_model(self, client, sample_llm_model_data):
        """Test creating a new LLM model"""
        response = client.post("/api/llm", json=sample_llm_model_data)
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == sample_llm_model_data["name"]
        assert data["vendor"] == sample_llm_model_data["vendor"]
        assert data["type"] == sample_llm_model_data["type"]
        assert data["base_url"] == sample_llm_model_data["base_url"]
        assert "id" in data

    def test_create_llm_model_minimal(self, client):
        """Test creating an LLM model with minimal required data"""
        data = {
            "name": "Minimal LLM",
            "vendor": "Test",
            "type": "text",
            "base_url": "https://api.test.com",
            "api_key": "test-key",
        }
        response = client.post("/api/llm", json=data)
        assert response.status_code == 200
        assert response.json()["name"] == "Minimal LLM"

    def test_get_llm_model_by_id(self, client, sample_llm_model_data):
        """Test getting a specific LLM model by ID"""
        model_id = self._create_model(client, sample_llm_model_data)

        response = client.get(f"/api/llm/{model_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == model_id
        assert data["name"] == sample_llm_model_data["name"]

    def test_get_llm_model_not_found(self, client):
        """Test getting a non-existent LLM model"""
        response = client.get("/api/llm/non-existent-id")
        assert response.status_code == 404

    def test_update_llm_model(self, client, sample_llm_model_data):
        """Test updating an LLM model"""
        model_id = self._create_model(client, sample_llm_model_data)

        # A partial payload: base_url/api_key are deliberately omitted.
        update_data = {
            "name": "Updated LLM Model",
            "vendor": "SiliconFlow",
            "type": "embedding",
            "temperature": 0.5,
            "context_length": 8192,
        }
        response = client.put(f"/api/llm/{model_id}", json=update_data)
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Updated LLM Model"
        assert data["vendor"] == "SiliconFlow"
        assert data["type"] == "embedding"
        assert data["temperature"] == 0.5
        assert data["context_length"] == 8192

    def test_delete_llm_model(self, client, sample_llm_model_data):
        """Test deleting an LLM model"""
        model_id = self._create_model(client, sample_llm_model_data)

        response = client.delete(f"/api/llm/{model_id}")
        assert response.status_code == 200

        # The model must no longer be retrievable.
        get_response = client.get(f"/api/llm/{model_id}")
        assert get_response.status_code == 404

    def test_list_llm_models_with_pagination(self, client, sample_llm_model_data):
        """Test listing LLM models with pagination"""
        # Distinct ids/names so the three creates do not collide.
        for i in range(3):
            self._create_model(
                client,
                {**sample_llm_model_data, "id": f"test-llm-{i}", "name": f"LLM Model {i}"},
            )

        response = client.get("/api/llm?page=1&limit=2")
        assert response.status_code == 200
        data = response.json()
        # total counts every row; list holds only the requested page.
        assert data["total"] == 3
        assert len(data["list"]) == 2

    def test_filter_llm_models_by_type(self, client, sample_llm_model_data):
        """Test filtering LLM models by type"""
        for model_type in ("text", "embedding", "rerank"):
            self._create_model(
                client,
                {
                    **sample_llm_model_data,
                    "id": f"test-llm-{model_type}",
                    "name": f"LLM {model_type}",
                    "type": model_type,
                },
            )

        response = client.get("/api/llm?model_type=text")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] >= 1
        models = data["list"]
        # Without this, the loop below would pass on an empty result.
        assert any(model["name"] == "LLM text" for model in models)
        for model in models:
            assert model["type"] == "text"

    def test_filter_llm_models_by_enabled(self, client, sample_llm_model_data):
        """Test filtering LLM models by enabled status"""
        self._create_model(
            client,
            {**sample_llm_model_data, "id": "test-llm-enabled", "name": "Enabled LLM", "enabled": True},
        )
        self._create_model(
            client,
            {**sample_llm_model_data, "id": "test-llm-disabled", "name": "Disabled LLM", "enabled": False},
        )

        response = client.get("/api/llm?enabled=true")
        assert response.status_code == 200
        models = response.json()["list"]
        names = {model["name"] for model in models}
        # Check both sides of the filter explicitly. The per-item loop alone
        # would pass vacuously if the endpoint returned nothing.
        assert "Enabled LLM" in names
        assert "Disabled LLM" not in names
        for model in models:
            assert model["enabled"] is True

    def test_create_llm_model_with_all_fields(self, client):
        """Test creating an LLM model with all fields"""
        data = {
            "id": "full-llm",
            "name": "Full LLM Model",
            "vendor": "OpenAI",
            "type": "text",
            "base_url": "https://api.openai.com/v1",
            "api_key": "sk-test",
            "model_name": "gpt-4",
            "temperature": 0.8,
            "context_length": 16384,
            "enabled": True,
        }
        response = client.post("/api/llm", json=data)
        assert response.status_code == 200
        result = response.json()
        assert result["name"] == "Full LLM Model"
        assert result["temperature"] == 0.8
        assert result["context_length"] == 16384

    def test_test_llm_model_success(self, client, sample_llm_model_data):
        """Test testing an LLM model with successful connection"""
        model_id = self._create_model(client, sample_llm_model_data)

        # Canned upstream reply; raise_for_status is a no-op, i.e. HTTP success.
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "OK"}}]
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = MagicMock()
        mock_client.post.return_value = mock_response
        # If the router opens the client with ``with httpx.Client(...) as c``,
        # ``c`` must be this same mock so ``c.post`` returns mock_response.
        mock_client.__enter__.return_value = mock_client
        mock_client.__exit__.return_value = False

        # Patching once, at the router module, is sufficient: a separate
        # ``@patch('httpx.Client')`` would target the very same attribute.
        with patch("app.routers.llm.httpx.Client", return_value=mock_client):
            response = client.post(f"/api/llm/{model_id}/test")

        assert response.status_code == 200
        assert response.json()["success"] is True

    def test_test_llm_model_failure(self, client, sample_llm_model_data):
        """Test testing an LLM model with failed connection"""
        model_id = self._create_model(client, sample_llm_model_data)

        # Simulate an upstream 401. NOTE(review): real httpx raises
        # httpx.HTTPStatusError here; a plain Exception assumes the router
        # catches broadly -- TODO confirm.
        mock_response = MagicMock()
        mock_response.status_code = 401
        mock_response.text = "Unauthorized"
        mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized"))

        mock_client = MagicMock()
        mock_client.post.return_value = mock_response
        mock_client.__enter__.return_value = mock_client
        mock_client.__exit__.return_value = False

        with patch("app.routers.llm.httpx.Client", return_value=mock_client):
            response = client.post(f"/api/llm/{model_id}/test")

        # The endpoint reports the failure in the body rather than via status.
        assert response.status_code == 200
        assert response.json()["success"] is False

    def test_different_llm_vendors(self, client):
        """Test creating LLM models with different vendors"""
        vendors = ["OpenAI", "SiliconFlow", "ZhipuAI", "Anthropic"]
        for vendor in vendors:
            data = {
                "id": f"test-{vendor.lower()}",
                "name": f"Test {vendor}",
                "vendor": vendor,
                "type": "text",
                "base_url": f"https://api.{vendor.lower()}.com/v1",
                "api_key": "test-key",
            }
            response = client.post("/api/llm", json=data)
            assert response.status_code == 200
            assert response.json()["vendor"] == vendor

    def test_embedding_llm_model(self, client):
        """Test creating an embedding LLM model"""
        data = {
            "id": "embedding-test",
            "name": "Embedding Model",
            "vendor": "OpenAI",
            "type": "embedding",
            "base_url": "https://api.openai.com/v1",
            "api_key": "test-key",
            "model_name": "text-embedding-3-small",
        }
        response = client.post("/api/llm", json=data)
        assert response.status_code == 200
        assert response.json()["type"] == "embedding"

    def test_preview_llm_model_success(self, client, sample_llm_model_data, monkeypatch):
        """Test LLM preview endpoint returns model reply."""
        from app.routers import llm as llm_router

        model_id = self._create_model(client, sample_llm_model_data)

        class DummyResponse:
            status_code = 200

            def json(self):
                return {
                    "choices": [{"message": {"content": "Preview OK"}}],
                    "usage": {"prompt_tokens": 10, "completion_tokens": 2, "total_tokens": 12}
                }

            @property
            def text(self):
                return '{"ok":true}'

        class DummyClient:
            """Stand-in for httpx.Client that also checks the outgoing request."""

            def __init__(self, *args, **kwargs):
                pass

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                return False

            def post(self, url, json=None, headers=None):
                # ``json`` must keep httpx's keyword name, so it shadows the module here.
                assert url.endswith("/chat/completions")
                assert headers["Authorization"] == f"Bearer {sample_llm_model_data['api_key']}"
                assert json["messages"][0]["role"] == "user"
                return DummyResponse()

        monkeypatch.setattr(llm_router.httpx, "Client", DummyClient)

        response = client.post(f"/api/llm/{model_id}/preview", json={"message": "hello"})
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert data["reply"] == "Preview OK"

    def test_preview_llm_model_reject_empty_message(self, client, sample_llm_model_data):
        """Test LLM preview endpoint validates message."""
        model_id = self._create_model(client, sample_llm_model_data)

        # A whitespace-only message must be rejected as a client error.
        response = client.post(f"/api/llm/{model_id}/preview", json={"message": " "})
        assert response.status_code == 400