346 lines
14 KiB
Python
346 lines
14 KiB
Python
"""Tests for Tools & Autotest API endpoints"""
|
|
import pytest
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
|
|
class TestToolsAPI:
    """Test cases for the Tools catalogue (/api/tools/list) and health endpoints."""

    # Tool IDs that /api/tools/list is expected to expose.
    _EXPECTED_TOOL_IDS = (
        "calculator",
        "code_interpreter",
        "current_time",
        "turn_on_camera",
        "turn_off_camera",
        "increase_volume",
        "decrease_volume",
    )

    def test_list_available_tools(self, client):
        """Test listing all available tools.

        The response must contain a ``tools`` collection that includes every
        ID in ``_EXPECTED_TOOL_IDS``.
        """
        response = client.get("/api/tools/list")
        assert response.status_code == 200
        data = response.json()
        assert "tools" in data

        tools = data["tools"]
        # Collect every missing ID so a single failure reports all of them
        # instead of stopping at the first absent tool.
        missing = [tool_id for tool_id in self._EXPECTED_TOOL_IDS if tool_id not in tools]
        assert not missing, f"missing tools: {missing}"

    def test_get_tool_detail(self, client):
        """Test getting a specific tool's details."""
        response = client.get("/api/tools/list/calculator")
        assert response.status_code == 200
        data = response.json()
        # "计算器" is the Chinese display name ("Calculator") expected from the API.
        assert data["name"] == "计算器"
        assert "parameters" in data

    def test_get_tool_detail_not_found(self, client):
        """Test that requesting a non-existent tool returns 404."""
        response = client.get("/api/tools/list/non-existent-tool")
        assert response.status_code == 404

    def test_health_check(self, client):
        """Test the health check endpoint reports status, timestamp and tools."""
        response = client.get("/api/tools/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"
        assert "timestamp" in data
        assert "tools" in data
|
|
|
|
|
|
class TestAutotestAPI:
    """Test cases for Autotest endpoints (/api/tools/autotest*, /api/tools/test-message)."""

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _create_model(client, path, payload):
        """Create a model via ``POST path`` and return its ``id``.

        Asserts a 2xx status first so that a failed create is reported as the
        real HTTP error rather than a confusing ``KeyError`` on ``["id"]``.
        """
        response = client.post(path, json=payload)
        assert 200 <= response.status_code < 300, response.text
        return response.json()["id"]

    @staticmethod
    def _ok_response(json_body=None):
        """Return a MagicMock HTTP response with status 200 and a no-op ``raise_for_status``.

        When *json_body* is given, ``response.json()`` returns it.
        """
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        if json_body is not None:
            response.json.return_value = json_body
        return response

    @staticmethod
    def _mock_http_client(**method_responses):
        """Return a MagicMock standing in for an ``httpx.Client`` instance.

        Each keyword maps a client method name (``get``, ``post``, ...) to the
        response that method should return. ``__enter__`` returns the mock
        itself so the same stubs are reachable through
        ``with httpx.Client(...) as c:``.
        """
        mock_client = MagicMock()
        for method_name, response in method_responses.items():
            getattr(mock_client, method_name).return_value = response
        mock_client.__enter__ = MagicMock(return_value=mock_client)
        mock_client.__exit__ = MagicMock(return_value=False)
        return mock_client

    # -------------------------------------------------------------------- tests

    def test_autotest_no_models(self, client):
        """Test autotest without specifying model IDs."""
        response = client.post("/api/tools/autotest")
        assert response.status_code == 200
        data = response.json()
        assert "id" in data
        assert "tests" in data
        assert "summary" in data
        # Even with no model IDs supplied, the run should report at least one check.
        assert data["summary"]["total"] > 0

    def test_autotest_with_llm_model(self, client, sample_llm_model_data):
        """Test autotest with an LLM model (ASR checks disabled)."""
        model_id = self._create_model(client, "/api/llm", sample_llm_model_data)

        response = client.post(
            "/api/tools/autotest",
            params={"llm_model_id": model_id, "test_asr": "false"},
        )
        assert response.status_code == 200
        data = response.json()
        assert "tests" in data
        assert "summary" in data

    def test_autotest_with_asr_model(self, client, sample_asr_model_data):
        """Test autotest with an ASR model (LLM checks disabled)."""
        model_id = self._create_model(client, "/api/asr", sample_asr_model_data)

        response = client.post(
            "/api/tools/autotest",
            params={"asr_model_id": model_id, "test_llm": "false"},
        )
        assert response.status_code == 200
        data = response.json()
        assert "tests" in data
        assert "summary" in data

    def test_autotest_with_both_models(self, client, sample_llm_model_data, sample_asr_model_data):
        """Test autotest with both LLM and ASR models."""
        llm_id = self._create_model(client, "/api/llm", sample_llm_model_data)
        asr_id = self._create_model(client, "/api/asr", sample_asr_model_data)

        response = client.post(
            "/api/tools/autotest",
            params={"llm_model_id": llm_id, "asr_model_id": asr_id},
        )
        assert response.status_code == 200
        data = response.json()
        assert "tests" in data
        assert "summary" in data

    def test_autotest_llm_model_success(self, client, sample_llm_model_data):
        """Test autotest for a specific LLM model with a mocked successful connection."""
        model_id = self._create_model(client, "/api/llm", sample_llm_model_data)

        # Stub the upstream call with a 200 chat-completion style payload.
        mock_response = self._ok_response({"choices": [{"message": {"content": "OK"}}]})
        # Also provide streamed chunks in case the router reads the body incrementally.
        mock_response.iter_bytes = MagicMock(return_value=[b"chunk1", b"chunk2"])
        mock_client = self._mock_http_client(post=mock_response)

        # Patch httpx.Client as seen by the tools router so its upstream call hits the mock.
        with patch("app.routers.tools.httpx.Client", return_value=mock_client):
            response = client.post(f"/api/tools/autotest/llm/{model_id}")

        assert response.status_code == 200
        data = response.json()
        assert "tests" in data
        assert "summary" in data

    def test_autotest_asr_model_success(self, client, sample_asr_model_data):
        """Test autotest for a specific ASR model with a mocked successful connection."""
        model_id = self._create_model(client, "/api/asr", sample_asr_model_data)

        # Only ``get`` is stubbed, mirroring what the ASR check is exercised with here.
        mock_client = self._mock_http_client(get=self._ok_response())

        with patch("app.routers.tools.httpx.Client", return_value=mock_client):
            response = client.post(f"/api/tools/autotest/asr/{model_id}")

        assert response.status_code == 200
        data = response.json()
        assert "tests" in data
        assert "summary" in data

    def test_autotest_llm_model_not_found(self, client):
        """Test autotest for a non-existent LLM model reports a failed check (not an HTTP error)."""
        response = client.post("/api/tools/autotest/llm/non-existent-id")
        assert response.status_code == 200
        data = response.json()
        assert any(not t["passed"] for t in data["tests"])

    def test_autotest_asr_model_not_found(self, client):
        """Test autotest for a non-existent ASR model reports a failed check (not an HTTP error)."""
        response = client.post("/api/tools/autotest/asr/non-existent-id")
        assert response.status_code == 200
        data = response.json()
        assert any(not t["passed"] for t in data["tests"])

    def test_test_message_success(self, client, sample_llm_model_data):
        """Test sending a test message to an LLM model with a mocked upstream reply."""
        model_id = self._create_model(client, "/api/llm", sample_llm_model_data)

        mock_response = self._ok_response({
            "choices": [{"message": {"content": "Hello! This is a test reply."}}],
            "usage": {"total_tokens": 10},
        })
        mock_client = self._mock_http_client(post=mock_response)

        with patch("app.routers.tools.httpx.Client", return_value=mock_client):
            response = client.post(
                "/api/tools/test-message",
                params={"llm_model_id": model_id},
                json={"message": "Hello!"},
            )

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert "reply" in data

    def test_test_message_model_not_found(self, client):
        """Test sending a test message to a non-existent model returns 404."""
        response = client.post(
            "/api/tools/test-message",
            params={"llm_model_id": "non-existent"},
            json={"message": "Hello!"},
        )
        assert response.status_code == 404

    def test_autotest_result_structure(self, client):
        """Test that autotest results and every individual test entry have the expected fields."""
        response = client.post("/api/tools/autotest")
        assert response.status_code == 200
        data = response.json()

        # Top-level report fields.
        for field in ("id", "started_at", "duration_ms", "tests", "summary"):
            assert field in data, f"missing report field: {field}"

        # Summary counters.
        for field in ("passed", "failed", "total"):
            assert field in data["summary"], f"missing summary field: {field}"

        # Check every test entry, not just the first one.
        for entry in data["tests"]:
            for field in ("name", "passed", "message", "duration_ms"):
                assert field in entry, f"test entry {entry!r} missing {field}"

    def test_tools_have_required_fields(self, client):
        """Test that every listed tool has name, description and a JSON-schema-like parameters object."""
        response = client.get("/api/tools/list")
        assert response.status_code == 200
        data = response.json()

        for tool_id, tool in data["tools"].items():
            for field in ("name", "description", "parameters"):
                assert field in tool, f"tool {tool_id!r} missing {field}"

            params = tool["parameters"]
            assert "type" in params, f"tool {tool_id!r} parameters missing 'type'"
            assert "properties" in params, f"tool {tool_id!r} parameters missing 'properties'"

    def test_calculator_tool_parameters(self, client):
        """Test the calculator tool declares a required ``expression`` parameter."""
        response = client.get("/api/tools/list/calculator")
        assert response.status_code == 200
        data = response.json()

        # "计算器" is the Chinese display name ("Calculator").
        assert data["name"] == "计算器"
        assert "expression" in data["parameters"]["properties"]
        assert "required" in data["parameters"]
        assert "expression" in data["parameters"]["required"]

    def test_code_interpreter_tool_parameters(self, client):
        """Test the code_interpreter tool declares a ``code`` parameter."""
        response = client.get("/api/tools/list/code_interpreter")
        assert response.status_code == 200
        data = response.json()

        # "代码执行" is the Chinese display name ("Code execution").
        assert data["name"] == "代码执行"
        assert "code" in data["parameters"]["properties"]
|
|
|
|
|
|
class TestToolResourceCRUD:
    """CRUD tests for the persistent tool-resource endpoints under /api/tools/resources."""

    def test_list_tool_resources_contains_system_tools(self, client):
        """The resource listing is non-empty and includes the built-in ``calculator`` tool."""
        listing = client.get("/api/tools/resources")
        assert listing.status_code == 200

        body = listing.json()
        assert body["total"] >= 1
        resource_ids = {entry["id"] for entry in body["list"]}
        assert "calculator" in resource_ids

    def test_create_update_delete_tool_resource(self, client):
        """A user-defined tool can be created, renamed/recategorised, fetched and deleted."""
        new_tool = {
            "name": "自定义网页抓取",  # "Custom web scraper"
            "description": "抓取页面并提取正文",  # "Fetch a page and extract its main text"
            "category": "query",
            "icon": "Globe",
            "http_method": "GET",
            "http_url": "https://example.com/search",
            "http_headers": {},
            "http_timeout_ms": 10000,
            "enabled": True,
        }
        created_resp = client.post("/api/tools/resources", json=new_tool)
        assert created_resp.status_code == 200

        created = created_resp.json()
        tool_id = created["id"]
        assert created["name"] == "自定义网页抓取"
        # Tools created through the API must not be flagged as system tools.
        assert created["is_system"] is False

        resource_url = f"/api/tools/resources/{tool_id}"

        changes = {
            "name": "自定义网页检索",  # "Custom web search"
            "category": "system",
        }
        updated_resp = client.put(resource_url, json=changes)
        assert updated_resp.status_code == 200
        updated = updated_resp.json()
        assert updated["name"] == "自定义网页检索"
        assert updated["category"] == "system"

        fetched = client.get(resource_url)
        assert fetched.status_code == 200
        assert fetched.json()["id"] == tool_id

        removed = client.delete(resource_url)
        assert removed.status_code == 200

        # Once deleted, the resource must no longer be retrievable.
        gone = client.get(resource_url)
        assert gone.status_code == 404

    def test_create_query_tool_requires_http_url(self, client):
        """Creating a ``query``-category tool without ``http_url`` is rejected with HTTP 400."""
        incomplete_tool = {
            "name": "缺失URL的查询工具",  # "Query tool missing its URL"
            "description": "应当失败",  # "Should fail"
            "category": "query",
            "icon": "Globe",
            "enabled": True,
        }
        rejected = client.post("/api/tools/resources", json=incomplete_tool)
        assert rejected.status_code == 400

    def test_system_tool_can_be_updated_and_deleted(self, client):
        """The built-in ``turn_on_camera`` tool can be updated and deleted like a user tool."""
        # NOTE(review): this deletes a seeded system tool. It assumes the
        # ``client`` fixture provides fresh state per test; TODO confirm so that
        # other tests expecting ``turn_on_camera`` are unaffected.
        system_url = "/api/tools/resources/turn_on_camera"

        listing = client.get("/api/tools/resources")
        assert listing.status_code == 200
        assert any(entry["id"] == "turn_on_camera" for entry in listing.json()["list"])

        renamed = client.put(
            system_url,
            json={"name": "更新后的打开摄像头", "category": "system"},  # "Updated 'turn on camera'"
        )
        assert renamed.status_code == 200
        assert renamed.json()["name"] == "更新后的打开摄像头"

        removed = client.delete(system_url)
        assert removed.status_code == 200

        gone = client.get(system_url)
        assert gone.status_code == 404
|