vad: introduce VADParams so you can tweak things

This commit is contained in:
Aleix Conchillo Flaqué
2024-05-17 13:01:43 -07:00
parent efa5a061d7
commit 537e72a05f
4 changed files with 29 additions and 21 deletions

View File

@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added initial interruptions support.
- Added `VADParams` so you can control voice confidence level and others.
### Fixed
- Fixed issues with Ctrl-C program termination.

View File

@@ -38,7 +38,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.vad.vad_analyzer import VADAnalyzer, VADState
from pipecat.vad.vad_analyzer import VADAnalyzer, VADParams, VADState
from loguru import logger
@@ -60,8 +60,8 @@ class DailyTransportMessageFrame(TransportMessageFrame):
class WebRTCVADAnalyzer(VADAnalyzer):
def __init__(self, sample_rate=16000, num_channels=1):
super().__init__(sample_rate, num_channels)
def __init__(self, sample_rate=16000, num_channels=1, params: VADParams = VADParams()):
super().__init__(sample_rate, num_channels, params)
self._webrtc_vad = Daily.create_native_vad(
reset_period_ms=VAD_RESET_PERIOD_MS,

View File

@@ -8,7 +8,7 @@ import numpy as np
from pipecat.frames.frames import AudioRawFrame, Frame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.vad.vad_analyzer import VADAnalyzer, VADState
from pipecat.vad.vad_analyzer import VADAnalyzer, VADParams, VADState
from loguru import logger
@@ -28,8 +28,8 @@ except ModuleNotFoundError as e:
class SileroVADAnalyzer(VADAnalyzer):
def __init__(self, sample_rate=16000):
super().__init__(sample_rate=sample_rate, num_channels=1)
def __init__(self, sample_rate=16000, params: VADParams = VADParams()):
super().__init__(sample_rate=sample_rate, num_channels=1, params=params)
logger.debug("Loading Silero VAD model...")
@@ -63,10 +63,14 @@ class SileroVADAnalyzer(VADAnalyzer):
class SileroVAD(FrameProcessor):
def __init__(self, sample_rate=16000, audio_passthrough=False):
def __init__(
self,
sample_rate: int = 16000,
vad_params: VADParams = VADParams(),
audio_passthrough: bool = False):
super().__init__()
self._vad_analyzer = SileroVADAnalyzer(sample_rate=sample_rate)
self._vad_analyzer = SileroVADAnalyzer(sample_rate=sample_rate, params=vad_params)
self._audio_passthrough = audio_passthrough
#

View File

@@ -7,6 +7,10 @@
from abc import abstractmethod
from enum import Enum
from pydantic.main import BaseModel
from pipecat.utils.utils import exp_smoothing
class VADState(Enum):
QUIET = 1
@@ -15,26 +19,24 @@ class VADState(Enum):
STOPPING = 4
class VADParams(BaseModel):
confidence: float = 0.8
start_secs: float = 0.2
stop_secs: float = 0.8
class VADAnalyzer:
def __init__(
self,
sample_rate: int,
num_channels: int,
vad_confidence: float = 0.5,
vad_start_secs: float = 0.2,
vad_stop_secs: float = 0.8):
def __init__(self, sample_rate: int, num_channels: int, params: VADParams):
self._sample_rate = sample_rate
self._vad_confidence = vad_confidence
self._vad_start_secs = vad_start_secs
self._vad_stop_secs = vad_stop_secs
self._params = params
self._vad_frames = self.num_frames_required()
self._vad_frames_num_bytes = self._vad_frames * num_channels * 2
vad_frames_per_sec = self._vad_frames / self._sample_rate
self._vad_start_frames = round(self._vad_start_secs / vad_frames_per_sec)
self._vad_stop_frames = round(self._vad_stop_secs / vad_frames_per_sec)
self._vad_start_frames = round(self._params.start_secs / vad_frames_per_sec)
self._vad_stop_frames = round(self._params.stop_secs / vad_frames_per_sec)
self._vad_starting_count = 0
self._vad_stopping_count = 0
self._vad_state: VADState = VADState.QUIET
@@ -64,7 +66,7 @@ class VADAnalyzer:
self._vad_buffer = self._vad_buffer[num_required_bytes:]
confidence = self.voice_confidence(audio_frames)
speaking = confidence >= self._vad_confidence
speaking = confidence >= self._params.confidence
if speaking:
match self._vad_state: