Merge pull request #3684 from ai-coustics/goedev/aic-model-caching
AIC model caching
This commit is contained in:
3
changelog/3684.changed.md
Normal file
3
changelog/3684.changed.md
Normal file
@@ -0,0 +1,3 @@
|
||||
- `AICFilter` now shares read-only AIC models via a singleton `AICModelManager` in `aic_filter.py`.
|
||||
- Multiple filters using the same model path or `(model_id, model_download_dir)` share one loaded model, with reference counting and concurrent load deduplication.
|
||||
- Model file I/O runs off the event loop so the filter does not block.
|
||||
@@ -12,10 +12,13 @@ the Koala filter and integrates with Pipecat's input transport pipeline.
|
||||
|
||||
Classes:
|
||||
AICFilter: For aic-sdk (uses 'aic_sdk' module)
|
||||
AICModelManager: Singleton manager for read-only AIC Model instances.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from threading import Lock
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from aic_sdk import (
|
||||
@@ -33,6 +36,177 @@ from pipecat.audio.vad.aic_vad import AICVADAnalyzer
|
||||
from pipecat.frames.frames import FilterControlFrame, FilterEnableFrame
|
||||
|
||||
|
||||
class AICModelManager:
|
||||
"""Singleton manager for read-only AIC Model instances with reference counting.
|
||||
|
||||
Caches Model instances by path or (model_id + download_dir). Multiple
|
||||
AICFilter instances using the same model share one Model; the manager
|
||||
acquires on first use and releases when the last reference is dropped.
|
||||
"""
|
||||
|
||||
_cache: dict[str, Tuple[Model, int]] = {} # key -> (model, ref_count)
|
||||
_lock = Lock()
|
||||
_loading: dict[
|
||||
str, asyncio.Task[Model]
|
||||
] = {} # key -> load task (deduplicates concurrent loads)
|
||||
|
||||
@classmethod
|
||||
def _increment_reference(cls, cache_key: str, entry: Tuple[Model, int]) -> Tuple[Model, str]:
|
||||
"""Increment reference count for cached entry. Caller must hold _lock."""
|
||||
cached_model, ref_count = entry
|
||||
cls._cache[cache_key] = (cached_model, ref_count + 1)
|
||||
logger.debug(f"AIC model cache key={cache_key!r} ref_count={ref_count + 1}")
|
||||
return cached_model, cache_key
|
||||
|
||||
@classmethod
|
||||
def _store_new_reference(cls, cache_key: str, model: Model) -> Tuple[Model, str]:
|
||||
"""Store new model in cache with ref count 1. Caller must hold _lock."""
|
||||
cls._cache[cache_key] = (model, 1)
|
||||
logger.debug(f"AIC model cached key={cache_key!r} ref_count=1")
|
||||
return model, cache_key
|
||||
|
||||
@classmethod
|
||||
async def _load_model_from_file(
|
||||
cls,
|
||||
cache_key: str,
|
||||
*,
|
||||
model_path: Optional[Path] = None,
|
||||
model_id: Optional[str] = None,
|
||||
model_download_dir: Optional[Path] = None,
|
||||
) -> Model:
|
||||
"""Run the actual load (file or download). Separate to allow create_task and deduplication."""
|
||||
if model_path is not None:
|
||||
logger.debug(f"Loading AIC model from file: {model_path}")
|
||||
model_path_str = str(model_path)
|
||||
|
||||
elif model_id is not None and model_download_dir is not None:
|
||||
logger.debug(f"Downloading AIC model: {model_id}")
|
||||
model_download_dir.mkdir(parents=True, exist_ok=True)
|
||||
model_path_str = await Model.download_async(model_id, str(model_download_dir))
|
||||
logger.debug(f"Model downloaded to: {model_path_str}")
|
||||
|
||||
else:
|
||||
raise ValueError("Unexpected model_path or (model_id and model_download_dir) state.")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, lambda: Model.from_file(model_path_str))
|
||||
|
||||
@staticmethod
|
||||
def _get_cache_key(
|
||||
*,
|
||||
model_path: Optional[Path] = None,
|
||||
model_id: Optional[str] = None,
|
||||
model_download_dir: Optional[Path] = None,
|
||||
) -> str:
|
||||
"""Build a stable cache key for the model.
|
||||
|
||||
Args:
|
||||
model_path: Path to a local .aicmodel file.
|
||||
model_id: Model identifier (See https://artifacts.ai-coustics.io/ for available models).
|
||||
model_download_dir: Directory used for downloading models.
|
||||
|
||||
Returns:
|
||||
A string key unique per (path) or (model_id + download_dir).
|
||||
"""
|
||||
if model_path is not None:
|
||||
return f"path:{model_path.resolve()}"
|
||||
|
||||
if model_id is not None and model_download_dir is not None:
|
||||
return f"id:{model_id}:{model_download_dir.resolve()}"
|
||||
|
||||
raise ValueError("Either model_path or (model_id and model_download_dir) must be set.")
|
||||
|
||||
@classmethod
|
||||
async def acquire(
|
||||
cls,
|
||||
*,
|
||||
model_path: Optional[Path] = None,
|
||||
model_id: Optional[str] = None,
|
||||
model_download_dir: Optional[Path] = None,
|
||||
) -> Tuple[Model, str]:
|
||||
"""Get or load a Model and increment its reference count.
|
||||
|
||||
Call this when starting a filter. Store the returned key and pass it
|
||||
to release() when stopping the filter.
|
||||
|
||||
Args:
|
||||
model_path: Path to a local .aicmodel file. If set, model_id is ignored.
|
||||
model_id: Model identifier to download from CDN.
|
||||
model_download_dir: Directory for downloading models. Required if
|
||||
model_id is used.
|
||||
|
||||
Returns:
|
||||
Tuple of (shared Model instance, cache key for release).
|
||||
|
||||
Raises:
|
||||
ValueError: If neither model_path nor (model_id + model_download_dir)
|
||||
is provided, or if model_id is set without model_download_dir.
|
||||
"""
|
||||
cache_key = cls._get_cache_key(
|
||||
model_path=model_path,
|
||||
model_id=model_id,
|
||||
model_download_dir=model_download_dir,
|
||||
)
|
||||
|
||||
with cls._lock:
|
||||
entry = cls._cache.get(cache_key)
|
||||
if entry is not None:
|
||||
return cls._increment_reference(cache_key, entry)
|
||||
|
||||
# Deduplicate concurrent loads for the same key
|
||||
load_task = cls._loading.get(cache_key)
|
||||
if load_task is None:
|
||||
load_task = asyncio.create_task(
|
||||
cls._load_model_from_file(
|
||||
cache_key,
|
||||
model_path=model_path,
|
||||
model_id=model_id,
|
||||
model_download_dir=model_download_dir,
|
||||
)
|
||||
)
|
||||
cls._loading[cache_key] = load_task
|
||||
|
||||
try:
|
||||
model = await load_task
|
||||
finally:
|
||||
with cls._lock:
|
||||
cls._loading.pop(cache_key, None)
|
||||
|
||||
with cls._lock:
|
||||
entry = cls._cache.get(cache_key)
|
||||
if entry is not None:
|
||||
return cls._increment_reference(cache_key, entry)
|
||||
return cls._store_new_reference(cache_key, model)
|
||||
|
||||
@classmethod
|
||||
def release(cls, key: str) -> None:
|
||||
"""Release a reference to a cached model.
|
||||
|
||||
Call this when stopping a filter, with the key returned from
|
||||
get_model(). When the last reference is released, the model
|
||||
is removed from the cache.
|
||||
|
||||
Args:
|
||||
key: Cache key returned by get_model().
|
||||
"""
|
||||
with cls._lock:
|
||||
entry = cls._cache.get(key)
|
||||
|
||||
if entry is None:
|
||||
logger.warning(f"AIC model release unknown key={key!r}")
|
||||
return
|
||||
|
||||
model, ref_count = entry
|
||||
ref_count -= 1
|
||||
|
||||
if ref_count <= 0:
|
||||
del cls._cache[key]
|
||||
logger.debug(f"AIC model evicted key={key!r}")
|
||||
else:
|
||||
cls._cache[key] = (model, ref_count)
|
||||
logger.debug(f"AIC model key={key!r} ref_count={ref_count}")
|
||||
|
||||
|
||||
class AICFilter(BaseAudioFilter):
|
||||
"""Audio filter using ai-coustics' AIC SDK for real-time enhancement.
|
||||
|
||||
@@ -91,7 +265,8 @@ class AICFilter(BaseAudioFilter):
|
||||
32768.0 # 2^15, for normalizing int16 (-32768 to 32767) to float32 (-1.0 to 1.0)
|
||||
)
|
||||
|
||||
# AIC SDK objects
|
||||
# AIC SDK objects; model is shared via AICModelManager
|
||||
self._model_cache_key: Optional[str] = None
|
||||
self._model = None
|
||||
self._processor = None
|
||||
self._processor_ctx = None
|
||||
@@ -162,16 +337,12 @@ class AICFilter(BaseAudioFilter):
|
||||
"""
|
||||
self._sample_rate = sample_rate
|
||||
|
||||
# Load or download model
|
||||
if self._model_path:
|
||||
logger.debug(f"Loading AIC model from: {self._model_path}")
|
||||
self._model = Model.from_file(str(self._model_path))
|
||||
else:
|
||||
logger.debug(f"Downloading AIC model: {self._model_id}")
|
||||
self._model_download_dir.mkdir(parents=True, exist_ok=True)
|
||||
model_path = await Model.download_async(self._model_id, str(self._model_download_dir))
|
||||
logger.debug(f"Model downloaded to: {model_path}")
|
||||
self._model = Model.from_file(model_path)
|
||||
# Acquire shared read-only model from singleton manager
|
||||
self._model, self._model_cache_key = await AICModelManager.acquire(
|
||||
model_path=self._model_path,
|
||||
model_id=self._model_id,
|
||||
model_download_dir=self._model_download_dir,
|
||||
)
|
||||
|
||||
# Get optimal frames for this sample rate
|
||||
self._frames_per_block = self._model.get_optimal_num_frames(self._sample_rate)
|
||||
@@ -242,6 +413,10 @@ class AICFilter(BaseAudioFilter):
|
||||
self._aic_ready = False
|
||||
self._audio_buffer.clear()
|
||||
|
||||
if self._model_cache_key is not None:
|
||||
AICModelManager.release(self._model_cache_key)
|
||||
self._model_cache_key = None
|
||||
|
||||
async def process_frame(self, frame: FilterControlFrame):
|
||||
"""Process control frames to enable/disable filtering.
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@@ -23,6 +24,13 @@ except ImportError:
|
||||
AIC_FILTER_MODULE = "pipecat.audio.filters.aic_filter"
|
||||
|
||||
|
||||
def _model_manager_ref_count(manager, key: str) -> int:
|
||||
"""Test helper: return reference count for a cache key (reads internal cache)."""
|
||||
with manager._lock:
|
||||
entry = manager._cache.get(key)
|
||||
return entry[1] if entry else 0
|
||||
|
||||
|
||||
class MockProcessor:
|
||||
"""A lightweight mock for AIC ProcessorAsync that mimics real behavior."""
|
||||
|
||||
@@ -99,10 +107,11 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""Import AICFilter after confirming aic_sdk is available."""
|
||||
from pipecat.audio.filters.aic_filter import AICFilter
|
||||
from pipecat.audio.filters.aic_filter import AICFilter, AICModelManager
|
||||
from pipecat.frames.frames import FilterEnableFrame
|
||||
|
||||
cls.AICFilter = AICFilter
|
||||
cls.AICModelManager = AICModelManager
|
||||
cls.FilterEnableFrame = FilterEnableFrame
|
||||
|
||||
def setUp(self):
|
||||
@@ -122,13 +131,13 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def _start_filter_with_mocks(self, filter_instance, sample_rate=16000):
|
||||
"""Start a filter with mocked SDK components."""
|
||||
cache_key = "test-cache-key"
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor),
|
||||
):
|
||||
mock_model_cls.from_file.return_value = self.mock_model
|
||||
mock_model_cls.download_async = AsyncMock(return_value="/tmp/model")
|
||||
mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, cache_key))
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
await filter_instance.start(sample_rate)
|
||||
|
||||
@@ -171,37 +180,44 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
|
||||
filter_instance = self._create_filter_with_mocks(model_id=None, model_path=model_path)
|
||||
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor),
|
||||
):
|
||||
mock_model_cls.from_file.return_value = self.mock_model
|
||||
mock_manager_cls.acquire = AsyncMock(
|
||||
return_value=(self.mock_model, "path:/tmp/test.aicmodel")
|
||||
)
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
|
||||
await filter_instance.start(16000)
|
||||
|
||||
mock_model_cls.from_file.assert_called_once_with(str(model_path))
|
||||
mock_manager_cls.acquire.assert_called_once()
|
||||
call_kw = mock_manager_cls.acquire.call_args[1]
|
||||
self.assertEqual(call_kw["model_path"], model_path)
|
||||
self.assertIsNone(call_kw["model_id"])
|
||||
self.assertTrue(filter_instance._aic_ready)
|
||||
self.assertEqual(filter_instance._sample_rate, 16000)
|
||||
self.assertEqual(filter_instance._frames_per_block, 160)
|
||||
|
||||
async def test_start_with_model_id_downloads(self):
|
||||
"""Test starting filter with model_id triggers download."""
|
||||
"""Test starting filter with model_id uses manager (download happens in manager)."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor),
|
||||
):
|
||||
mock_model_cls.from_file.return_value = self.mock_model
|
||||
mock_model_cls.download_async = AsyncMock(return_value="/tmp/model")
|
||||
mock_manager_cls.acquire = AsyncMock(
|
||||
return_value=(self.mock_model, "id:test-model:/custom/cache")
|
||||
)
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
|
||||
await filter_instance.start(16000)
|
||||
|
||||
mock_model_cls.download_async.assert_called_once()
|
||||
mock_model_cls.from_file.assert_called_once()
|
||||
mock_manager_cls.acquire.assert_called_once()
|
||||
call_kw = mock_manager_cls.acquire.call_args[1]
|
||||
self.assertEqual(call_kw["model_id"], "test-model")
|
||||
self.assertTrue(filter_instance._aic_ready)
|
||||
|
||||
async def test_start_creates_processor(self):
|
||||
@@ -209,14 +225,13 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(
|
||||
f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor
|
||||
) as mock_processor_cls,
|
||||
):
|
||||
mock_model_cls.from_file.return_value = self.mock_model
|
||||
mock_model_cls.download_async = AsyncMock(return_value="/tmp/model")
|
||||
mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, "test-cache-key"))
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
|
||||
await filter_instance.start(16000)
|
||||
@@ -241,17 +256,21 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(bypass_params[-1][1], 0.0)
|
||||
|
||||
async def test_stop_cleans_up_resources(self):
|
||||
"""Test that stop properly cleans up resources."""
|
||||
"""Test that stop properly cleans up resources and releases model reference."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
await self._start_filter_with_mocks(filter_instance)
|
||||
cache_key = filter_instance._model_cache_key
|
||||
|
||||
await filter_instance.stop()
|
||||
with patch(f"{AIC_FILTER_MODULE}.AICModelManager.release") as mock_release:
|
||||
await filter_instance.stop()
|
||||
|
||||
mock_release.assert_called_once_with(cache_key)
|
||||
self.assertTrue(self.mock_processor.processor_ctx.reset_called)
|
||||
self.assertIsNone(filter_instance._processor)
|
||||
self.assertIsNone(filter_instance._processor_ctx)
|
||||
self.assertIsNone(filter_instance._vad_ctx)
|
||||
self.assertIsNone(filter_instance._model)
|
||||
self.assertIsNone(filter_instance._model_cache_key)
|
||||
self.assertFalse(filter_instance._aic_ready)
|
||||
|
||||
async def test_stop_without_start(self):
|
||||
@@ -261,6 +280,177 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
|
||||
# Should not raise
|
||||
await filter_instance.stop()
|
||||
|
||||
async def test_model_manager_reference_count(self):
|
||||
"""Test that AICModelManager reference count increments and decrements correctly."""
|
||||
model_path = Path("/tmp/refcount-test.aicmodel")
|
||||
mock_model = MockModel()
|
||||
manager = self.AICModelManager
|
||||
|
||||
with patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls:
|
||||
mock_model_cls.from_file.return_value = mock_model
|
||||
|
||||
# Acquire first reference
|
||||
model1, key = await manager.acquire(model_path=model_path)
|
||||
self.assertEqual(model1, mock_model)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key), 1)
|
||||
|
||||
# Acquire second reference (same key, cached)
|
||||
model2, key2 = await manager.acquire(model_path=model_path)
|
||||
self.assertIs(model2, model1)
|
||||
self.assertEqual(key2, key)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key), 2)
|
||||
|
||||
# Release one reference
|
||||
manager.release(key)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key), 1)
|
||||
|
||||
# Release last reference (model evicted from cache)
|
||||
manager.release(key)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key), 0)
|
||||
|
||||
async def test_model_manager_concurrent_load_deduplication(self):
|
||||
"""Test that concurrent acquire calls for the same key share a single load task."""
|
||||
model_path = Path("/tmp/concurrent-load-test.aicmodel")
|
||||
mock_model = MockModel()
|
||||
manager = self.AICModelManager
|
||||
load_count = 0
|
||||
|
||||
def from_file_once(path):
|
||||
nonlocal load_count
|
||||
load_count += 1
|
||||
time.sleep(0.02) # yield so other acquire callers can hit _loading and await same task
|
||||
return mock_model
|
||||
|
||||
with patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls:
|
||||
mock_model_cls.from_file.side_effect = from_file_once
|
||||
|
||||
# Start several acquire calls concurrently before any completes
|
||||
results = await asyncio.gather(
|
||||
manager.acquire(model_path=model_path),
|
||||
manager.acquire(model_path=model_path),
|
||||
manager.acquire(model_path=model_path),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
load_count, 1, "Model.from_file should be called once for concurrent callers"
|
||||
)
|
||||
model1, key1 = results[0]
|
||||
model2, key2 = results[1]
|
||||
model3, key3 = results[2]
|
||||
self.assertIs(model1, mock_model)
|
||||
self.assertIs(model2, mock_model)
|
||||
self.assertIs(model3, mock_model)
|
||||
self.assertEqual(key1, key2)
|
||||
self.assertEqual(key2, key3)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key1), 3)
|
||||
|
||||
# Release all references
|
||||
manager.release(key1)
|
||||
manager.release(key1)
|
||||
manager.release(key1)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key1), 0)
|
||||
|
||||
async def test_load_model_from_file_invalid_args_raises(self):
|
||||
"""Test _load_model_from_file defensive else: raises ValueError."""
|
||||
manager = self.AICModelManager
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
await manager._load_model_from_file(
|
||||
"key",
|
||||
model_path=None,
|
||||
model_id=None,
|
||||
model_download_dir=None,
|
||||
)
|
||||
self.assertIn("Unexpected", str(ctx.exception))
|
||||
|
||||
async def test_model_manager_acquire_by_model_id_hits_download_path(self):
|
||||
"""Test acquire with model_id runs download path in _load_model_from_file."""
|
||||
model_id = "test-model-id"
|
||||
model_download_dir = Path("/tmp/aic-downloads")
|
||||
mock_model = MockModel()
|
||||
manager = self.AICModelManager
|
||||
|
||||
with patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls:
|
||||
mock_model_cls.download_async = AsyncMock(
|
||||
return_value="/tmp/aic-downloads/model.aicmodel"
|
||||
)
|
||||
mock_model_cls.from_file.return_value = mock_model
|
||||
|
||||
model, key = await manager.acquire(
|
||||
model_id=model_id,
|
||||
model_download_dir=model_download_dir,
|
||||
)
|
||||
|
||||
mock_model_cls.download_async.assert_called_once()
|
||||
mock_model_cls.from_file.assert_called_once_with("/tmp/aic-downloads/model.aicmodel")
|
||||
self.assertIs(model, mock_model)
|
||||
self.assertEqual(_model_manager_ref_count(manager, key), 1)
|
||||
manager.release(key)
|
||||
|
||||
def test_get_cache_key_invalid_raises(self):
|
||||
"""Test _get_cache_key raises ValueError for invalid args."""
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
self.AICModelManager._get_cache_key(model_path=None, model_id=None)
|
||||
self.assertIn("model_path", str(ctx.exception))
|
||||
|
||||
with self.assertRaises(ValueError) as ctx2:
|
||||
self.AICModelManager._get_cache_key(
|
||||
model_path=None,
|
||||
model_id="x",
|
||||
model_download_dir=None,
|
||||
)
|
||||
self.assertIn("model_download_dir", str(ctx2.exception))
|
||||
|
||||
async def test_start_processor_init_failure(self):
|
||||
"""Test start() when ProcessorAsync raises: exception logged, _aic_ready False."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(
|
||||
f"{AIC_FILTER_MODULE}.ProcessorAsync",
|
||||
side_effect=RuntimeError("SDK init failed"),
|
||||
),
|
||||
):
|
||||
mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, "test-key"))
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
|
||||
await filter_instance.start(16000)
|
||||
|
||||
self.assertIsNone(filter_instance._processor)
|
||||
self.assertFalse(filter_instance._aic_ready)
|
||||
|
||||
async def test_start_parameter_fixed_error_logged(self):
|
||||
"""Test start() when set_parameter raises ParameterFixedError: logged, no raise."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
self.mock_processor.processor_ctx.set_parameter = MagicMock(
|
||||
side_effect=aic_sdk.ParameterFixedError("fixed")
|
||||
)
|
||||
|
||||
with (
|
||||
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
|
||||
patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor),
|
||||
):
|
||||
mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, "test-key"))
|
||||
mock_config_cls.optimal.return_value = MagicMock()
|
||||
|
||||
await filter_instance.start(16000)
|
||||
|
||||
self.assertTrue(filter_instance._aic_ready)
|
||||
|
||||
async def test_process_frame_set_parameter_exception_logged(self):
|
||||
"""Test process_frame when set_parameter raises: exception logged, no raise."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
await self._start_filter_with_mocks(filter_instance)
|
||||
filter_instance._processor_ctx.set_parameter = MagicMock(
|
||||
side_effect=ValueError("param error")
|
||||
)
|
||||
|
||||
await filter_instance.process_frame(self.FilterEnableFrame(enable=True))
|
||||
|
||||
self.assertFalse(filter_instance._bypass)
|
||||
|
||||
async def test_process_frame_enable(self):
|
||||
"""Test processing FilterEnableFrame to enable filtering."""
|
||||
filter_instance = self._create_filter_with_mocks()
|
||||
|
||||
Reference in New Issue
Block a user