From 2036757b84062dc0804f77ecf2466c39b3932461 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kmen=20G=C3=B6rgen?= Date: Wed, 11 Feb 2026 15:22:37 +0100 Subject: [PATCH] add unit tests for `AICModelManager` and `AICFilter` error handling, model loading, and processor behavior --- changelog/3684.changed.md | 3 +- src/pipecat/audio/filters/aic_filter.py | 15 ++-- tests/test_aic_filter.py | 101 ++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 7 deletions(-) diff --git a/changelog/3684.changed.md b/changelog/3684.changed.md index 017fc9946..1bdb2c89c 100644 --- a/changelog/3684.changed.md +++ b/changelog/3684.changed.md @@ -1,2 +1,3 @@ - `AICFilter` now shares read-only AIC models via a singleton `AICModelManager` in `aic_filter.py`. - - Multiple filters using the same `model path` or `(model_id, model_download_dir)` share one loaded model, with reference counting and concurrent load deduplication. + - Multiple filters using the same model path or `(model_id, model_download_dir)` share one loaded model, with reference counting and concurrent load deduplication. + - Model file I/O runs off the event loop so the filter does not block. diff --git a/src/pipecat/audio/filters/aic_filter.py b/src/pipecat/audio/filters/aic_filter.py index 399e6cd2e..723b3da8f 100644 --- a/src/pipecat/audio/filters/aic_filter.py +++ b/src/pipecat/audio/filters/aic_filter.py @@ -77,16 +77,19 @@ class AICModelManager: """Run the actual load (file or download). Separate to allow create_task and deduplication.""" if model_path is not None: logger.debug(f"Loading AIC model from file: {model_path}") - return Model.from_file(str(model_path)) + model_path_str = str(model_path) - if model_id is not None and model_download_dir is not None: + elif model_id is not None and model_download_dir is not None: logger.debug(f"Downloading AIC model: {model_id}") model_download_dir.mkdir(parents=True, exist_ok=True) - path = await Model.download_async(model_id, str(model_download_dir)) - logger.debug(f"Model downloaded to: {path}") - return Model.from_file(path) + model_path_str = await Model.download_async(model_id, str(model_download_dir)) + logger.debug(f"Model downloaded to: {model_path_str}") - raise ValueError("Unexpected model_path or (model_id and model_download_dir) state.") + else: + raise ValueError("Unexpected model_path or (model_id and model_download_dir) state.") + + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, lambda: Model.from_file(model_path_str)) @staticmethod def _get_cache_key( diff --git a/tests/test_aic_filter.py b/tests/test_aic_filter.py index c36022a7b..83eca757c 100644 --- a/tests/test_aic_filter.py +++ b/tests/test_aic_filter.py @@ -350,6 +350,107 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase): manager.release(key1) self.assertEqual(_model_manager_ref_count(manager, key1), 0) + async def test_load_model_from_file_invalid_args_raises(self): + """Test _load_model_from_file defensive else: raises ValueError.""" + manager = self.AICModelManager + with self.assertRaises(ValueError) as ctx: + await manager._load_model_from_file( + "key", + model_path=None, + model_id=None, + model_download_dir=None, + ) + self.assertIn("Unexpected", str(ctx.exception)) + + async def test_model_manager_acquire_by_model_id_hits_download_path(self): + """Test acquire with model_id runs download path in _load_model_from_file.""" + model_id = "test-model-id" + model_download_dir = Path("/tmp/aic-downloads") + mock_model = MockModel() + manager = self.AICModelManager + + with patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls: + mock_model_cls.download_async = AsyncMock( + return_value="/tmp/aic-downloads/model.aicmodel" + ) + mock_model_cls.from_file.return_value = mock_model + + model, key = await manager.acquire( + model_id=model_id, + model_download_dir=model_download_dir, + ) + + mock_model_cls.download_async.assert_called_once() + mock_model_cls.from_file.assert_called_once_with("/tmp/aic-downloads/model.aicmodel") + self.assertIs(model, mock_model) + self.assertEqual(_model_manager_ref_count(manager, key), 1) + manager.release(key) + + def test_get_cache_key_invalid_raises(self): + """Test _get_cache_key raises ValueError for invalid args.""" + with self.assertRaises(ValueError) as ctx: + self.AICModelManager._get_cache_key(model_path=None, model_id=None) + self.assertIn("model_path", str(ctx.exception)) + + with self.assertRaises(ValueError) as ctx2: + self.AICModelManager._get_cache_key( + model_path=None, + model_id="x", + model_download_dir=None, + ) + self.assertIn("model_download_dir", str(ctx2.exception)) + + async def test_start_processor_init_failure(self): + """Test start() when ProcessorAsync raises: exception logged, _aic_ready False.""" + filter_instance = self._create_filter_with_mocks() + + with ( + patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls, + patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls, + patch( + f"{AIC_FILTER_MODULE}.ProcessorAsync", + side_effect=RuntimeError("SDK init failed"), + ), + ): + mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, "test-key")) + mock_config_cls.optimal.return_value = MagicMock() + + await filter_instance.start(16000) + + self.assertIsNone(filter_instance._processor) + self.assertFalse(filter_instance._aic_ready) + + async def test_start_parameter_fixed_error_logged(self): + """Test start() when set_parameter raises ParameterFixedError: logged, no raise.""" + filter_instance = self._create_filter_with_mocks() + self.mock_processor.processor_ctx.set_parameter = MagicMock( + side_effect=aic_sdk.ParameterFixedError("fixed") + ) + + with ( + patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls, + patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls, + patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor), + ): + mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, "test-key")) + mock_config_cls.optimal.return_value = MagicMock() + + await filter_instance.start(16000) + + self.assertTrue(filter_instance._aic_ready) + + async def test_process_frame_set_parameter_exception_logged(self): + """Test process_frame when set_parameter raises: exception logged, no raise.""" + filter_instance = self._create_filter_with_mocks() + await self._start_filter_with_mocks(filter_instance) + filter_instance._processor_ctx.set_parameter = MagicMock( + side_effect=ValueError("param error") + ) + + await filter_instance.process_frame(self.FilterEnableFrame(enable=True)) + + self.assertFalse(filter_instance._bypass) + async def test_process_frame_enable(self): """Test processing FilterEnableFrame to enable filtering.""" filter_instance = self._create_filter_with_mocks()