diff --git a/CHANGELOG.md b/CHANGELOG.md index b1c9876ee..c381ac630 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added initial interruptions support. +- Added `VADParams` so you can control the voice confidence threshold and the speech start/stop durations. + ### Fixed - Fixed issues with Ctrl-C program termination. diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index e0d406d8e..47a1b9925 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -38,7 +38,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport from pipecat.transports.base_transport import BaseTransport, TransportParams -from pipecat.vad.vad_analyzer import VADAnalyzer, VADState +from pipecat.vad.vad_analyzer import VADAnalyzer, VADParams, VADState from loguru import logger @@ -60,8 +60,8 @@ class DailyTransportMessageFrame(TransportMessageFrame): class WebRTCVADAnalyzer(VADAnalyzer): - def __init__(self, sample_rate=16000, num_channels=1): - super().__init__(sample_rate, num_channels) + def __init__(self, sample_rate=16000, num_channels=1, params: VADParams = VADParams()): + super().__init__(sample_rate, num_channels, params) self._webrtc_vad = Daily.create_native_vad( reset_period_ms=VAD_RESET_PERIOD_MS, diff --git a/src/pipecat/vad/silero.py b/src/pipecat/vad/silero.py index bfe13affe..ab7cf36df 100644 --- a/src/pipecat/vad/silero.py +++ b/src/pipecat/vad/silero.py @@ -8,7 +8,7 @@ import numpy as np from pipecat.frames.frames import AudioRawFrame, Frame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.vad.vad_analyzer import VADAnalyzer, VADState +from pipecat.vad.vad_analyzer import 
VADAnalyzer, VADParams, VADState from loguru import logger @@ -28,8 +28,8 @@ except ModuleNotFoundError as e: class SileroVADAnalyzer(VADAnalyzer): - def __init__(self, sample_rate=16000): - super().__init__(sample_rate=sample_rate, num_channels=1) + def __init__(self, sample_rate=16000, params: VADParams = VADParams()): + super().__init__(sample_rate=sample_rate, num_channels=1, params=params) logger.debug("Loading Silero VAD model...") @@ -63,10 +63,14 @@ class SileroVADAnalyzer(VADAnalyzer): class SileroVAD(FrameProcessor): - def __init__(self, sample_rate=16000, audio_passthrough=False): + def __init__( + self, + sample_rate: int = 16000, + vad_params: VADParams = VADParams(), + audio_passthrough: bool = False): super().__init__() - self._vad_analyzer = SileroVADAnalyzer(sample_rate=sample_rate) + self._vad_analyzer = SileroVADAnalyzer(sample_rate=sample_rate, params=vad_params) self._audio_passthrough = audio_passthrough # diff --git a/src/pipecat/vad/vad_analyzer.py b/src/pipecat/vad/vad_analyzer.py index 58bec3b9a..6c4afceba 100644 --- a/src/pipecat/vad/vad_analyzer.py +++ b/src/pipecat/vad/vad_analyzer.py @@ -7,6 +7,10 @@ from abc import abstractmethod from enum import Enum +from pydantic.main import BaseModel + +from pipecat.utils.utils import exp_smoothing + class VADState(Enum): QUIET = 1 @@ -15,26 +19,24 @@ class VADState(Enum): STOPPING = 4 +class VADParams(BaseModel): + confidence: float = 0.8 + start_secs: float = 0.2 + stop_secs: float = 0.8 + + class VADAnalyzer: - def __init__( - self, - sample_rate: int, - num_channels: int, - vad_confidence: float = 0.5, - vad_start_secs: float = 0.2, - vad_stop_secs: float = 0.8): + def __init__(self, sample_rate: int, num_channels: int, params: VADParams): self._sample_rate = sample_rate - self._vad_confidence = vad_confidence - self._vad_start_secs = vad_start_secs - self._vad_stop_secs = vad_stop_secs + self._params = params self._vad_frames = self.num_frames_required() self._vad_frames_num_bytes = 
self._vad_frames * num_channels * 2 vad_frames_per_sec = self._vad_frames / self._sample_rate - self._vad_start_frames = round(self._vad_start_secs / vad_frames_per_sec) - self._vad_stop_frames = round(self._vad_stop_secs / vad_frames_per_sec) + self._vad_start_frames = round(self._params.start_secs / vad_frames_per_sec) + self._vad_stop_frames = round(self._params.stop_secs / vad_frames_per_sec) self._vad_starting_count = 0 self._vad_stopping_count = 0 self._vad_state: VADState = VADState.QUIET @@ -64,7 +66,7 @@ class VADAnalyzer: self._vad_buffer = self._vad_buffer[num_required_bytes:] confidence = self.voice_confidence(audio_frames) - speaking = confidence >= self._vad_confidence + speaking = confidence >= self._params.confidence if speaking: match self._vad_state: