139 lines
3.3 KiB
Python
139 lines
3.3 KiB
Python
"""Pytest fixtures for API tests"""
|
|
import os
|
|
import sys
|
|
import pytest
|
|
|
|
# The ``app`` package lives in the API project root, i.e. the parent of the
# directory containing this file. Prepend that root to the import search path
# so the ``from app...`` imports below resolve regardless of the working
# directory pytest is launched from.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)),
)
|
|
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.db import Base, get_db
|
|
from app.main import app
|
|
|
|
|
|
# Tests run against an in-memory SQLite database; nothing is written to disk.
DATABASE_URL = "sqlite:///:memory:"

# StaticPool keeps a single connection for every checkout, so all sessions
# see the same in-memory database (a second connection would get a new, empty
# one). SQLite's same-thread check is disabled because that one connection may
# be used from more than one thread — presumably by request handling under
# TestClient; confirm if the pool setup changes.
engine = create_engine(
    DATABASE_URL,
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)

# Session factory bound to the test engine; sessions neither autocommit nor
# autoflush.
TestingSessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def db_session():
    """Yield a database session backed by a freshly created schema.

    Every table registered on ``Base.metadata`` is created on the test engine
    before the test runs and dropped again during teardown, so each test
    starts from empty tables.

    Yields:
        A session from ``TestingSessionLocal`` bound to the test engine.
    """
    Base.metadata.create_all(bind=engine)
    # Outer ``finally`` guarantees the schema is dropped whenever it was
    # created — even if opening or closing the session raises. The engine
    # shares one in-memory connection across tests, so leftover tables/rows
    # would otherwise leak into the next test.
    try:
        session = TestingSessionLocal()
        try:
            yield session
        finally:
            session.close()
    finally:
        Base.metadata.drop_all(bind=engine)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def client(db_session):
    """Yield a ``TestClient`` whose ``get_db`` dependency returns ``db_session``.

    Endpoints called through the client receive the very same Session object
    the test holds, so rows written through either are visible to the other.
    The session's lifecycle belongs to the ``db_session`` fixture, so the
    override never closes it.

    Yields:
        A ``TestClient`` for ``app``, used as a context manager.
    """

    def override_get_db():
        # Hand out the shared test session; cleanup is done by ``db_session``.
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        # Runs even if entering the TestClient context raises (e.g. an app
        # startup failure), so the override cannot leak into later tests.
        app.dependency_overrides.clear()
|
|
|
|
|
|
@pytest.fixture
def sample_voice_data():
    """Return the attributes of an enabled test voice as a new dict per test."""
    return dict(
        name="Test Voice",
        vendor="TestVendor",
        gender="Female",
        language="zh",
        description="A test voice for unit testing",
        model="test-model",
        voice_key="test-key",
        speed=1.0,
        gain=0,
        pitch=0,
        enabled=True,
    )
|
|
|
|
|
|
@pytest.fixture
def sample_assistant_data():
    """Return the attributes of a test assistant as a new dict per test.

    Keys are camelCase where the original payload uses camelCase
    (``voiceOutputEnabled``, ``configMode``).
    """
    return dict(
        name="Test Assistant",
        opener="Hello, welcome!",
        prompt="You are a helpful assistant.",
        language="zh",
        voiceOutputEnabled=True,
        speed=1.0,
        hotwords=["test", "hello"],
        tools=[],
        configMode="platform",
    )
|
|
|
|
|
|
@pytest.fixture
def sample_call_record_data():
    """Return call-record attributes (user 1, no assistant, ``debug`` source)."""
    return dict(user_id=1, assistant_id=None, source="debug")
|
|
|
|
|
|
@pytest.fixture
def sample_llm_model_data():
    """Return the attributes of an enabled text LLM model as a new dict."""
    return dict(
        id="test-llm-001",
        name="Test LLM Model",
        vendor="TestVendor",
        type="text",
        base_url="https://api.test.com/v1",
        api_key="test-api-key",
        model_name="test-model",
        temperature=0.7,
        context_length=4096,
        enabled=True,
    )
|
|
|
|
|
|
@pytest.fixture
def sample_asr_model_data():
    """Return the attributes of an enabled Chinese ASR model as a new dict."""
    return dict(
        id="test-asr-001",
        name="Test ASR Model",
        vendor="TestVendor",
        language="zh",
        base_url="https://api.test.com/v1",
        api_key="test-api-key",
        model_name="paraformer-v2",
        # Chinese hotwords meaning "test" and "speech".
        hotwords=["测试", "语音"],
        enable_punctuation=True,
        enable_normalization=True,
        enabled=True,
    )
|