add unit tests for AICModelManager and AICFilter error handling, model loading, and processor behavior
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
- `AICFilter` now shares read-only AIC models via a singleton `AICModelManager` in `aic_filter.py`.
|
||||
- Multiple filters using the same `model path` or `(model_id, model_download_dir)` share one loaded model, with reference counting and concurrent load deduplication.
|
||||
- Multiple filters using the same model path or `(model_id, model_download_dir)` share one loaded model, with reference counting and concurrent load deduplication.
|
||||
- Model file I/O runs off the event loop so the filter does not block.
|
||||
|
||||
@@ -77,16 +77,19 @@ class AICModelManager:
|
||||
"""Run the actual load (file or download). Separate to allow create_task and deduplication."""
|
||||
if model_path is not None:
|
||||
logger.debug(f"Loading AIC model from file: {model_path}")
|
||||
return Model.from_file(str(model_path))
|
||||
model_path_str = str(model_path)
|
||||
|
||||
if model_id is not None and model_download_dir is not None:
|
||||
elif model_id is not None and model_download_dir is not None:
|
||||
logger.debug(f"Downloading AIC model: {model_id}")
|
||||
model_download_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = await Model.download_async(model_id, str(model_download_dir))
|
||||
logger.debug(f"Model downloaded to: {path}")
|
||||
return Model.from_file(path)
|
||||
model_path_str = await Model.download_async(model_id, str(model_download_dir))
|
||||
logger.debug(f"Model downloaded to: {model_path_str}")
|
||||
|
||||
raise ValueError("Unexpected model_path or (model_id and model_download_dir) state.")
|
||||
else:
|
||||
raise ValueError("Unexpected model_path or (model_id and model_download_dir) state.")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, lambda: Model.from_file(model_path_str))
|
||||
|
||||
@staticmethod
|
||||
def _get_cache_key(
|
||||
|
||||
@@ -350,6 +350,107 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
|
||||
manager.release(key1)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key1), 0)
|
||||
|
||||
async def test_load_model_from_file_invalid_args_raises(self):
|
||||
"""Test _load_model_from_file defensive else: raises ValueError."""
|
||||
manager = self.AICModelManager
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
await manager._load_model_from_file(
|
||||
"key",
|
||||
model_path=None,
|
||||
model_id=None,
|
||||
model_download_dir=None,
|
||||
)
|
||||
self.assertIn("Unexpected", str(ctx.exception))
|
||||
|
||||
async def test_model_manager_acquire_by_model_id_hits_download_path(self):
|
||||
"""Test acquire with model_id runs download path in _load_model_from_file."""
|
||||
model_id = "test-model-id"
|
||||
model_download_dir = Path("/tmp/aic-downloads")
|
||||
mock_model = MockModel()
|
||||
manager = self.AICModelManager
|
||||
|
||||
with patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls:
|
||||
mock_model_cls.download_async = AsyncMock(
|
||||
return_value="/tmp/aic-downloads/model.aicmodel"
|
||||
)
|
||||
mock_model_cls.from_file.return_value = mock_model
|
||||
|
||||
model, key = await manager.acquire(
|
||||
model_id=model_id,
|
||||
model_download_dir=model_download_dir,
|
||||
)
|
||||
|
||||
mock_model_cls.download_async.assert_called_once()
|
||||
mock_model_cls.from_file.assert_called_once_with("/tmp/aic-downloads/model.aicmodel")
|
||||
self.assertIs(model, mock_model)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key), 1)
|
||||
manager.release(key)
|
||||
|
||||
def test_get_cache_key_invalid_raises(self):
|
||||
"""Test _get_cache_key raises ValueError for invalid args."""
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
self.AICModelManager._get_cache_key(model_path=None, model_id=None)
|
||||
self.assertIn("model_path", str(ctx.exception))
|
||||
|
||||
with self.assertRaises(ValueError) as ctx2:
|
||||
self.AICModelManager._get_cache_key(
|
||||
model_path=None,
|
||||
model_id="x",
|
||||
model_download_dir=None,
|
||||
)
|
||||
self.assertIn("model_download_dir", str(ctx2.exception))
|
||||
|
||||
async def test_start_processor_init_failure(self):
|
||||
"""Test start() when ProcessorAsync raises: exception logged, _aic_ready False."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(
|
||||
f"{AIC_FILTER_MODULE}.ProcessorAsync",
|
||||
side_effect=RuntimeError("SDK init failed"),
|
||||
),
|
||||
):
|
||||
mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, "test-key"))
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
|
||||
await filter_instance.start(16000)
|
||||
|
||||
self.assertIsNone(filter_instance._processor)
|
||||
self.assertFalse(filter_instance._aic_ready)
|
||||
|
||||
async def test_start_parameter_fixed_error_logged(self):
|
||||
"""Test start() when set_parameter raises ParameterFixedError: logged, no raise."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
self.mock_processor.processor_ctx.set_parameter = MagicMock(
|
||||
side_effect=aic_sdk.ParameterFixedError("fixed")
|
||||
)
|
||||
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor),
|
||||
):
|
||||
mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, "test-key"))
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
|
||||
await filter_instance.start(16000)
|
||||
|
||||
self.assertTrue(filter_instance._aic_ready)
|
||||
|
||||
async def test_process_frame_set_parameter_exception_logged(self):
|
||||
"""Test process_frame when set_parameter raises: exception logged, no raise."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
await self._start_filter_with_mocks(filter_instance)
|
||||
filter_instance._processor_ctx.set_parameter = MagicMock(
|
||||
side_effect=ValueError("param error")
|
||||
)
|
||||
|
||||
await filter_instance.process_frame(self.FilterEnableFrame(enable=True))
|
||||
|
||||
self.assertFalse(filter_instance._bypass)
|
||||
|
||||
async def test_process_frame_enable(self):
|
||||
"""Test processing FilterEnableFrame to enable filtering."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
|
||||
Reference in New Issue
Block a user