BaseInputTransport: create VAD thread in VADAnalyzer

We move the thread creation to the VADAnalyzer instead of the input
transport. This can potentially be useful if we need to analyze multiple audio
streams.
This commit is contained in:
Aleix Conchillo Flaqué
2025-09-25 11:37:16 -07:00
parent de3461e4cc
commit f53fd880dc
2 changed files with 14 additions and 9 deletions

View File

@@ -11,7 +11,9 @@ data structures for voice activity detection in audio streams. Includes state
management, parameter configuration, and audio analysis framework.
"""
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from typing import Optional
@@ -84,6 +86,10 @@ class VADAnalyzer(ABC):
self._smoothing_factor = 0.2
self._prev_volume = 0
# Thread executor that will run the model. We only need one thread per
# analyzer because one analyzer just handles one audio stream.
self._executor = ThreadPoolExecutor(max_workers=1)
@property
def sample_rate(self) -> int:
"""Get the current sample rate.
@@ -165,7 +171,7 @@ class VADAnalyzer(ABC):
volume = calculate_audio_volume(audio, self.sample_rate)
return exp_smoothing(volume, self._prev_volume, self._smoothing_factor)
def analyze_audio(self, buffer) -> VADState:
async def analyze_audio(self, buffer: bytes) -> VADState:
"""Analyze audio buffer and return current VAD state.
Processes incoming audio data, maintains internal state, and determines
@@ -177,6 +183,12 @@ class VADAnalyzer(ABC):
Returns:
Current VAD state after processing the buffer.
"""
loop = asyncio.get_running_loop()
state = await loop.run_in_executor(self._executor, self._run_analyzer, buffer)
return state
def _run_analyzer(self, buffer: bytes) -> VADState:
"""Analyze audio buffer and return current VAD state."""
self._vad_buffer += buffer
num_required_bytes = self._vad_frames_num_bytes

View File

@@ -11,7 +11,6 @@ input processing, including VAD, turn analysis, and interruption management.
"""
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from loguru import logger
@@ -79,10 +78,6 @@ class BaseInputTransport(FrameProcessor):
# Track user speaking state for interruption logic
self._user_speaking = False
# We read audio from a single queue one at a time and we then run VAD in
# a thread. Therefore, only one thread should be necessary.
self._executor = ThreadPoolExecutor(max_workers=1)
# Task to process incoming audio (VAD) and push audio frames downstream
# if passthrough is enabled.
self._audio_task = None
@@ -398,9 +393,7 @@ class BaseInputTransport(FrameProcessor):
"""Analyze audio frame for voice activity."""
state = VADState.QUIET
if self.vad_analyzer:
state = await self.get_event_loop().run_in_executor(
self._executor, self.vad_analyzer.analyze_audio, audio_frame.audio
)
state = await self.vad_analyzer.analyze_audio(audio_frame.audio)
return state
async def _handle_vad(self, audio_frame: InputAudioRawFrame, vad_state: VADState) -> VADState: