BaseInputTransport: create VAD thread in VADAnalyzer
We move the thread creation to the VADAnalyzer instead of the input transport. This can potentially be useful if we need to analyze multiple audio streams.
This commit is contained in:
@@ -11,7 +11,9 @@ data structures for voice activity detection in audio streams. Includes state
|
||||
management, parameter configuration, and audio analysis framework.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
@@ -84,6 +86,10 @@ class VADAnalyzer(ABC):
|
||||
self._smoothing_factor = 0.2
|
||||
self._prev_volume = 0
|
||||
|
||||
# Thread executor that will run the model. We only need one thread per
|
||||
# analyzer because one analyzer just handles one audio stream.
|
||||
self._executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
"""Get the current sample rate.
|
||||
@@ -165,7 +171,7 @@ class VADAnalyzer(ABC):
|
||||
volume = calculate_audio_volume(audio, self.sample_rate)
|
||||
return exp_smoothing(volume, self._prev_volume, self._smoothing_factor)
|
||||
|
||||
def analyze_audio(self, buffer) -> VADState:
|
||||
async def analyze_audio(self, buffer: bytes) -> VADState:
|
||||
"""Analyze audio buffer and return current VAD state.
|
||||
|
||||
Processes incoming audio data, maintains internal state, and determines
|
||||
@@ -177,6 +183,12 @@ class VADAnalyzer(ABC):
|
||||
Returns:
|
||||
Current VAD state after processing the buffer.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
state = await loop.run_in_executor(self._executor, self._run_analyzer, buffer)
|
||||
return state
|
||||
|
||||
def _run_analyzer(self, buffer: bytes) -> VADState:
|
||||
"""Analyze audio buffer and return current VAD state."""
|
||||
self._vad_buffer += buffer
|
||||
|
||||
num_required_bytes = self._vad_frames_num_bytes
|
||||
|
||||
@@ -11,7 +11,6 @@ input processing, including VAD, turn analysis, and interruption management.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
@@ -79,10 +78,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
# Track user speaking state for interruption logic
|
||||
self._user_speaking = False
|
||||
|
||||
# We read audio from a single queue one at a time and we then run VAD in
|
||||
# a thread. Therefore, only one thread should be necessary.
|
||||
self._executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# Task to process incoming audio (VAD) and push audio frames downstream
|
||||
# if passthrough is enabled.
|
||||
self._audio_task = None
|
||||
@@ -398,9 +393,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
"""Analyze audio frame for voice activity."""
|
||||
state = VADState.QUIET
|
||||
if self.vad_analyzer:
|
||||
state = await self.get_event_loop().run_in_executor(
|
||||
self._executor, self.vad_analyzer.analyze_audio, audio_frame.audio
|
||||
)
|
||||
state = await self.vad_analyzer.analyze_audio(audio_frame.audio)
|
||||
return state
|
||||
|
||||
async def _handle_vad(self, audio_frame: InputAudioRawFrame, vad_state: VADState) -> VADState:
|
||||
|
||||
Reference in New Issue
Block a user