Merge pull request #3684 from ai-coustics/goedev/aic-model-caching

AIC model caching
This commit is contained in:
Mark Backman
2026-02-16 10:43:14 -05:00
committed by GitHub
3 changed files with 398 additions and 30 deletions

View File

@@ -0,0 +1,3 @@
- `AICFilter` now shares read-only AIC models via a singleton `AICModelManager` in `aic_filter.py`.
- Multiple filters using the same model path or `(model_id, model_download_dir)` share one loaded model, with reference counting and concurrent load deduplication.
- Model file I/O runs off the event loop so the filter does not block.

View File

@@ -12,10 +12,13 @@ the Koala filter and integrates with Pipecat's input transport pipeline.
Classes:
AICFilter: For aic-sdk (uses 'aic_sdk' module)
AICModelManager: Singleton manager for read-only AIC Model instances.
"""
import asyncio
from pathlib import Path
from typing import List, Optional
from threading import Lock
from typing import List, Optional, Tuple
import numpy as np
from aic_sdk import (
@@ -33,6 +36,177 @@ from pipecat.audio.vad.aic_vad import AICVADAnalyzer
from pipecat.frames.frames import FilterControlFrame, FilterEnableFrame
class AICModelManager:
    """Singleton manager for read-only AIC Model instances with reference counting.

    Caches Model instances by path or (model_id + download_dir). Multiple
    AICFilter instances using the same model share one Model; the manager
    loads it on the first ``acquire()`` and evicts it from the cache when the
    last reference is dropped via ``release()``.

    All state is stored on the class itself and accessed through classmethods,
    so the manager is never instantiated.
    """

    _cache: dict[str, Tuple[Model, int]] = {}  # key -> (model, ref_count)
    _lock = Lock()  # guards _cache and _loading; never held across an await
    _loading: dict[
        str, asyncio.Task[Model]
    ] = {}  # key -> load task (deduplicates concurrent loads)

    @classmethod
    def _increment_reference(cls, cache_key: str, entry: Tuple[Model, int]) -> Tuple[Model, str]:
        """Increment reference count for cached entry. Caller must hold _lock.

        Args:
            cache_key: Key under which ``entry`` is stored in the cache.
            entry: The current ``(model, ref_count)`` tuple for that key.

        Returns:
            Tuple of (cached Model instance, cache key).
        """
        cached_model, ref_count = entry
        cls._cache[cache_key] = (cached_model, ref_count + 1)
        logger.debug(f"AIC model cache key={cache_key!r} ref_count={ref_count + 1}")
        return cached_model, cache_key

    @classmethod
    def _store_new_reference(cls, cache_key: str, model: Model) -> Tuple[Model, str]:
        """Store new model in cache with ref count 1. Caller must hold _lock.

        Args:
            cache_key: Key to store the model under.
            model: Freshly loaded Model instance.

        Returns:
            Tuple of (stored Model instance, cache key).
        """
        cls._cache[cache_key] = (model, 1)
        logger.debug(f"AIC model cached key={cache_key!r} ref_count=1")
        return model, cache_key

    @classmethod
    def _discard_load_task(cls, cache_key: str, load_task: "asyncio.Task[Model]") -> None:
        """Drop a finished load task from ``_loading``.

        The entry is removed only if it is still this exact task, so a newer
        load registered under the same key is never discarded by mistake.

        Args:
            cache_key: Key the task was registered under.
            load_task: The task that has finished (successfully or not).
        """
        with cls._lock:
            if cls._loading.get(cache_key) is load_task:
                del cls._loading[cache_key]

    @classmethod
    async def _load_model_from_file(
        cls,
        cache_key: str,
        *,
        model_path: Optional[Path] = None,
        model_id: Optional[str] = None,
        model_download_dir: Optional[Path] = None,
    ) -> Model:
        """Run the actual load (file or download).

        Kept separate from ``acquire()`` so the load can be wrapped in a task
        and shared between concurrent callers.

        Args:
            cache_key: Cache key of the model being loaded. Not used by the
                loader itself.
            model_path: Path to a local .aicmodel file. Takes precedence over
                ``model_id`` when set.
            model_id: Model identifier to download.
            model_download_dir: Directory the model is downloaded into; created
                if missing.

        Returns:
            The loaded Model instance.

        Raises:
            ValueError: If neither ``model_path`` nor both ``model_id`` and
                ``model_download_dir`` are provided.
        """
        if model_path is not None:
            logger.debug(f"Loading AIC model from file: {model_path}")
            model_path_str = str(model_path)
        elif model_id is not None and model_download_dir is not None:
            logger.debug(f"Downloading AIC model: {model_id}")
            model_download_dir.mkdir(parents=True, exist_ok=True)
            model_path_str = await Model.download_async(model_id, str(model_download_dir))
            logger.debug(f"Model downloaded to: {model_path_str}")
        else:
            raise ValueError("Unexpected model_path or (model_id and model_download_dir) state.")

        # Model.from_file is a synchronous call; run it in the default executor
        # so the event loop is not blocked while the file is read.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: Model.from_file(model_path_str))

    @staticmethod
    def _get_cache_key(
        *,
        model_path: Optional[Path] = None,
        model_id: Optional[str] = None,
        model_download_dir: Optional[Path] = None,
    ) -> str:
        """Build a stable cache key for the model.

        Args:
            model_path: Path to a local .aicmodel file.
            model_id: Model identifier (See https://artifacts.ai-coustics.io/ for available models).
            model_download_dir: Directory used for downloading models.

        Returns:
            A string key unique per (path) or (model_id + download_dir).

        Raises:
            ValueError: If neither ``model_path`` nor both ``model_id`` and
                ``model_download_dir`` are provided.
        """
        if model_path is not None:
            return f"path:{model_path.resolve()}"
        if model_id is not None and model_download_dir is not None:
            return f"id:{model_id}:{model_download_dir.resolve()}"
        raise ValueError("Either model_path or (model_id and model_download_dir) must be set.")

    @classmethod
    async def acquire(
        cls,
        *,
        model_path: Optional[Path] = None,
        model_id: Optional[str] = None,
        model_download_dir: Optional[Path] = None,
    ) -> Tuple[Model, str]:
        """Get or load a Model and increment its reference count.

        Call this when starting a filter. Store the returned key and pass it
        to release() when stopping the filter.

        Args:
            model_path: Path to a local .aicmodel file. If set, model_id is ignored.
            model_id: Model identifier to download from CDN.
            model_download_dir: Directory for downloading models. Required if
                model_id is used.

        Returns:
            Tuple of (shared Model instance, cache key for release).

        Raises:
            ValueError: If neither model_path nor (model_id + model_download_dir)
                is provided, or if model_id is set without model_download_dir.
        """
        cache_key = cls._get_cache_key(
            model_path=model_path,
            model_id=model_id,
            model_download_dir=model_download_dir,
        )

        with cls._lock:
            entry = cls._cache.get(cache_key)
            if entry is not None:
                return cls._increment_reference(cache_key, entry)

            # Deduplicate concurrent loads for the same key
            load_task = cls._loading.get(cache_key)
            if load_task is None:
                load_task = asyncio.create_task(
                    cls._load_model_from_file(
                        cache_key,
                        model_path=model_path,
                        model_id=model_id,
                        model_download_dir=model_download_dir,
                    )
                )
                cls._loading[cache_key] = load_task
                # The task removes itself from _loading once it finishes. The
                # callback is scheduled on the event loop (never run inline),
                # so registering it while holding _lock cannot deadlock.
                load_task.add_done_callback(
                    lambda finished: cls._discard_load_task(cache_key, finished)
                )

        # The load task is shared between all concurrent callers for this key.
        # Shield it so that cancelling one caller does not cancel the load for
        # everyone else waiting on the same task.
        model = await asyncio.shield(load_task)

        with cls._lock:
            # Another waiter on the same task may already have stored the model.
            entry = cls._cache.get(cache_key)
            if entry is not None:
                return cls._increment_reference(cache_key, entry)
            return cls._store_new_reference(cache_key, model)

    @classmethod
    def release(cls, key: str) -> None:
        """Release a reference to a cached model.

        Call this when stopping a filter, with the key returned from
        acquire(). When the last reference is released, the model
        is removed from the cache. Releasing an unknown key only logs a
        warning.

        Args:
            key: Cache key returned by acquire().
        """
        with cls._lock:
            entry = cls._cache.get(key)
            if entry is None:
                logger.warning(f"AIC model release unknown key={key!r}")
                return
            model, ref_count = entry
            ref_count -= 1
            if ref_count <= 0:
                del cls._cache[key]
                logger.debug(f"AIC model evicted key={key!r}")
            else:
                cls._cache[key] = (model, ref_count)
                logger.debug(f"AIC model key={key!r} ref_count={ref_count}")
class AICFilter(BaseAudioFilter):
"""Audio filter using ai-coustics' AIC SDK for real-time enhancement.
@@ -91,7 +265,8 @@ class AICFilter(BaseAudioFilter):
32768.0 # 2^15, for normalizing int16 (-32768 to 32767) to float32 (-1.0 to 1.0)
)
# AIC SDK objects
# AIC SDK objects; model is shared via AICModelManager
self._model_cache_key: Optional[str] = None
self._model = None
self._processor = None
self._processor_ctx = None
@@ -162,16 +337,12 @@ class AICFilter(BaseAudioFilter):
"""
self._sample_rate = sample_rate
# Load or download model
if self._model_path:
logger.debug(f"Loading AIC model from: {self._model_path}")
self._model = Model.from_file(str(self._model_path))
else:
logger.debug(f"Downloading AIC model: {self._model_id}")
self._model_download_dir.mkdir(parents=True, exist_ok=True)
model_path = await Model.download_async(self._model_id, str(self._model_download_dir))
logger.debug(f"Model downloaded to: {model_path}")
self._model = Model.from_file(model_path)
# Acquire shared read-only model from singleton manager
self._model, self._model_cache_key = await AICModelManager.acquire(
model_path=self._model_path,
model_id=self._model_id,
model_download_dir=self._model_download_dir,
)
# Get optimal frames for this sample rate
self._frames_per_block = self._model.get_optimal_num_frames(self._sample_rate)
@@ -242,6 +413,10 @@ class AICFilter(BaseAudioFilter):
self._aic_ready = False
self._audio_buffer.clear()
if self._model_cache_key is not None:
AICModelManager.release(self._model_cache_key)
self._model_cache_key = None
async def process_frame(self, frame: FilterControlFrame):
"""Process control frames to enable/disable filtering.

View File

@@ -5,6 +5,7 @@
#
import asyncio
import time
import unittest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
@@ -23,6 +24,13 @@ except ImportError:
AIC_FILTER_MODULE = "pipecat.audio.filters.aic_filter"
def _model_manager_ref_count(manager, key: str) -> int:
"""Test helper: return reference count for a cache key (reads internal cache)."""
with manager._lock:
entry = manager._cache.get(key)
return entry[1] if entry else 0
class MockProcessor:
"""A lightweight mock for AIC ProcessorAsync that mimics real behavior."""
@@ -99,10 +107,11 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
"""Import AICFilter after confirming aic_sdk is available."""
from pipecat.audio.filters.aic_filter import AICFilter
from pipecat.audio.filters.aic_filter import AICFilter, AICModelManager
from pipecat.frames.frames import FilterEnableFrame
cls.AICFilter = AICFilter
cls.AICModelManager = AICModelManager
cls.FilterEnableFrame = FilterEnableFrame
def setUp(self):
@@ -122,13 +131,13 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
async def _start_filter_with_mocks(self, filter_instance, sample_rate=16000):
"""Start a filter with mocked SDK components."""
cache_key = "test-cache-key"
with (
patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls,
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor),
):
mock_model_cls.from_file.return_value = self.mock_model
mock_model_cls.download_async = AsyncMock(return_value="/tmp/model")
mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, cache_key))
mock_config_cls.optimal.return_value = MagicMock()
await filter_instance.start(sample_rate)
@@ -171,37 +180,44 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
filter_instance = self._create_filter_with_mocks(model_id=None, model_path=model_path)
with (
patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls,
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor),
):
mock_model_cls.from_file.return_value = self.mock_model
mock_manager_cls.acquire = AsyncMock(
return_value=(self.mock_model, "path:/tmp/test.aicmodel")
)
mock_config_cls.optimal.return_value = MagicMock()
await filter_instance.start(16000)
mock_model_cls.from_file.assert_called_once_with(str(model_path))
mock_manager_cls.acquire.assert_called_once()
call_kw = mock_manager_cls.acquire.call_args[1]
self.assertEqual(call_kw["model_path"], model_path)
self.assertIsNone(call_kw["model_id"])
self.assertTrue(filter_instance._aic_ready)
self.assertEqual(filter_instance._sample_rate, 16000)
self.assertEqual(filter_instance._frames_per_block, 160)
async def test_start_with_model_id_downloads(self):
"""Test starting filter with model_id triggers download."""
"""Test starting filter with model_id uses manager (download happens in manager)."""
filter_instance = self._create_filter_with_mocks()
with (
patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls,
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor),
):
mock_model_cls.from_file.return_value = self.mock_model
mock_model_cls.download_async = AsyncMock(return_value="/tmp/model")
mock_manager_cls.acquire = AsyncMock(
return_value=(self.mock_model, "id:test-model:/custom/cache")
)
mock_config_cls.optimal.return_value = MagicMock()
await filter_instance.start(16000)
mock_model_cls.download_async.assert_called_once()
mock_model_cls.from_file.assert_called_once()
mock_manager_cls.acquire.assert_called_once()
call_kw = mock_manager_cls.acquire.call_args[1]
self.assertEqual(call_kw["model_id"], "test-model")
self.assertTrue(filter_instance._aic_ready)
async def test_start_creates_processor(self):
@@ -209,14 +225,13 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
filter_instance = self._create_filter_with_mocks()
with (
patch(f"{AIC_FILTER_MODULE}.Model") as mock_model_cls,
patch(f"{AIC_FILTER_MODULE}.AICModelManager") as mock_manager_cls,
patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as mock_config_cls,
patch(
f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor
) as mock_processor_cls,
):
mock_model_cls.from_file.return_value = self.mock_model
mock_model_cls.download_async = AsyncMock(return_value="/tmp/model")
mock_manager_cls.acquire = AsyncMock(return_value=(self.mock_model, "test-cache-key"))
mock_config_cls.optimal.return_value = MagicMock()
await filter_instance.start(16000)
@@ -241,17 +256,21 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
self.assertEqual(bypass_params[-1][1], 0.0)
async def test_stop_cleans_up_resources(self):
"""Test that stop properly cleans up resources."""
"""Test that stop properly cleans up resources and releases model reference."""
filter_instance = self._create_filter_with_mocks()
await self._start_filter_with_mocks(filter_instance)
cache_key = filter_instance._model_cache_key
await filter_instance.stop()
with patch(f"{AIC_FILTER_MODULE}.AICModelManager.release") as mock_release:
await filter_instance.stop()
mock_release.assert_called_once_with(cache_key)
self.assertTrue(self.mock_processor.processor_ctx.reset_called)
self.assertIsNone(filter_instance._processor)
self.assertIsNone(filter_instance._processor_ctx)
self.assertIsNone(filter_instance._vad_ctx)
self.assertIsNone(filter_instance._model)
self.assertIsNone(filter_instance._model_cache_key)
self.assertFalse(filter_instance._aic_ready)
async def test_stop_without_start(self):
@@ -261,6 +280,177 @@ class TestAICFilter(unittest.IsolatedAsyncioTestCase):
# Should not raise
await filter_instance.stop()
async def test_model_manager_reference_count(self):
    """Check that acquire() and release() move the reference count by one each."""
    manager = self.AICModelManager
    path = Path("/tmp/refcount-test.aicmodel")
    shared_model = MockModel()

    with patch(f"{AIC_FILTER_MODULE}.Model") as model_cls:
        model_cls.from_file.return_value = shared_model

        # The first acquire loads the model and starts the count at 1.
        first_model, key = await manager.acquire(model_path=path)
        self.assertEqual(first_model, shared_model)
        self.assertEqual(_model_manager_ref_count(manager, key), 1)

        # A second acquire for the same path returns the cached instance.
        second_model, second_key = await manager.acquire(model_path=path)
        self.assertIs(second_model, first_model)
        self.assertEqual(second_key, key)
        self.assertEqual(_model_manager_ref_count(manager, key), 2)

        # One release leaves a single reference; the next one evicts the
        # entry, which the helper reports as a count of 0.
        for remaining in (1, 0):
            manager.release(key)
            self.assertEqual(_model_manager_ref_count(manager, key), remaining)
async def test_model_manager_concurrent_load_deduplication(self):
    """Check that simultaneous acquire() calls for one key trigger a single load."""
    manager = self.AICModelManager
    path = Path("/tmp/concurrent-load-test.aicmodel")
    shared_model = MockModel()
    loaded_paths = []

    def slow_from_file(path_str):
        loaded_paths.append(path_str)
        # Keep the load in flight for a moment so the other acquire callers
        # find the pending task in _loading and wait on it.
        time.sleep(0.02)
        return shared_model

    with patch(f"{AIC_FILTER_MODULE}.Model") as model_cls:
        model_cls.from_file.side_effect = slow_from_file

        # Launch three acquire calls together, before any of them completes.
        results = await asyncio.gather(
            *(manager.acquire(model_path=path) for _ in range(3))
        )

    self.assertEqual(
        len(loaded_paths), 1, "Model.from_file should be called once for concurrent callers"
    )

    models = [model for model, _ in results]
    keys = [key for _, key in results]
    for model in models:
        self.assertIs(model, shared_model)
    self.assertEqual(keys[0], keys[1])
    self.assertEqual(keys[1], keys[2])
    self.assertEqual(_model_manager_ref_count(manager, keys[0]), 3)

    # Drop all three references; the entry must be gone afterwards.
    for _ in range(3):
        manager.release(keys[0])
    self.assertEqual(_model_manager_ref_count(manager, keys[0]), 0)
async def test_load_model_from_file_invalid_args_raises(self):
    """Check that _load_model_from_file raises ValueError when no model source is given."""
    no_source = {"model_path": None, "model_id": None, "model_download_dir": None}

    with self.assertRaises(ValueError) as raised:
        await self.AICModelManager._load_model_from_file("key", **no_source)

    self.assertIn("Unexpected", str(raised.exception))
async def test_model_manager_acquire_by_model_id_hits_download_path(self):
    """Check that acquire() with a model_id downloads, then loads the downloaded file."""
    manager = self.AICModelManager
    shared_model = MockModel()
    downloaded_file = "/tmp/aic-downloads/model.aicmodel"

    with patch(f"{AIC_FILTER_MODULE}.Model") as model_cls:
        model_cls.download_async = AsyncMock(return_value=downloaded_file)
        model_cls.from_file.return_value = shared_model

        model, key = await manager.acquire(
            model_id="test-model-id",
            model_download_dir=Path("/tmp/aic-downloads"),
        )

        # The path returned by the download is the one handed to from_file.
        model_cls.download_async.assert_called_once()
        model_cls.from_file.assert_called_once_with(downloaded_file)

    self.assertIs(model, shared_model)
    self.assertEqual(_model_manager_ref_count(manager, key), 1)

    # Drop the reference so the class-level cache is left empty for this key.
    manager.release(key)
def test_get_cache_key_invalid_raises(self):
    """Check that _get_cache_key rejects argument sets that identify no model."""
    build_key = self.AICModelManager._get_cache_key

    # Neither a path nor an id.
    with self.assertRaises(ValueError) as no_source:
        build_key(model_path=None, model_id=None)
    self.assertIn("model_path", str(no_source.exception))

    # An id without the download directory it requires.
    with self.assertRaises(ValueError) as no_download_dir:
        build_key(model_path=None, model_id="x", model_download_dir=None)
    self.assertIn("model_download_dir", str(no_download_dir.exception))
async def test_start_processor_init_failure(self):
    """Check start() when ProcessorAsync raises: no processor is kept and _aic_ready stays False."""
    filter_instance = self._create_filter_with_mocks()
    failing_processor = patch(
        f"{AIC_FILTER_MODULE}.ProcessorAsync",
        side_effect=RuntimeError("SDK init failed"),
    )

    with (
        patch(f"{AIC_FILTER_MODULE}.AICModelManager") as manager_cls,
        patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as config_cls,
        failing_processor,
    ):
        manager_cls.acquire = AsyncMock(return_value=(self.mock_model, "test-key"))
        config_cls.optimal.return_value = MagicMock()

        # start() is expected to return normally despite the constructor error.
        await filter_instance.start(16000)

    self.assertIsNone(filter_instance._processor)
    self.assertFalse(filter_instance._aic_ready)
async def test_start_parameter_fixed_error_logged(self):
    """Check start() when set_parameter raises ParameterFixedError: no raise, filter is ready."""
    filter_instance = self._create_filter_with_mocks()
    rejecting_setter = MagicMock(side_effect=aic_sdk.ParameterFixedError("fixed"))
    self.mock_processor.processor_ctx.set_parameter = rejecting_setter

    with (
        patch(f"{AIC_FILTER_MODULE}.AICModelManager") as manager_cls,
        patch(f"{AIC_FILTER_MODULE}.ProcessorConfig") as config_cls,
        patch(f"{AIC_FILTER_MODULE}.ProcessorAsync", return_value=self.mock_processor),
    ):
        manager_cls.acquire = AsyncMock(return_value=(self.mock_model, "test-key"))
        config_cls.optimal.return_value = MagicMock()

        await filter_instance.start(16000)

    self.assertTrue(filter_instance._aic_ready)
async def test_process_frame_set_parameter_exception_logged(self):
    """Check process_frame when set_parameter raises: no raise, _bypass ends up False."""
    filter_instance = self._create_filter_with_mocks()
    await self._start_filter_with_mocks(filter_instance)

    # Make every parameter update on the started filter fail.
    failing_setter = MagicMock(side_effect=ValueError("param error"))
    filter_instance._processor_ctx.set_parameter = failing_setter

    enable_frame = self.FilterEnableFrame(enable=True)
    await filter_instance.process_frame(enable_frame)

    self.assertFalse(filter_instance._bypass)
async def test_process_frame_enable(self):
"""Test processing FilterEnableFrame to enable filtering."""
filter_instance = self._create_filter_with_mocks()