VAD fallback (#97)
* Silero VAD preferred with webrtc fallback * webrtc VAD neds a different sample size * fixup * fixup
This commit is contained in:
@@ -5,6 +5,7 @@ import signal
|
||||
import threading
|
||||
import types
|
||||
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
@@ -33,6 +34,11 @@ except ModuleNotFoundError as e:
|
||||
|
||||
from dailyai.transports.threaded_transport import ThreadedTransport
|
||||
|
||||
NUM_CHANNELS = 1
|
||||
|
||||
SPEECH_THRESHOLD = 0.90
|
||||
VAD_RESET_PERIOD_MS = 2000
|
||||
|
||||
|
||||
class DailyTransport(ThreadedTransport, EventHandler):
|
||||
_daily_initialized = False
|
||||
@@ -55,6 +61,7 @@ class DailyTransport(ThreadedTransport, EventHandler):
|
||||
start_transcription: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
kwargs['has_webrtc_vad'] = True
|
||||
# This will call ThreadedTransport.__init__ method, not EventHandler
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -86,6 +93,12 @@ class DailyTransport(ThreadedTransport, EventHandler):
|
||||
|
||||
self._event_handlers = {}
|
||||
|
||||
self.webrtc_vad = Daily.create_native_vad(
|
||||
reset_period_ms=VAD_RESET_PERIOD_MS,
|
||||
sample_rate=self._speaker_sample_rate,
|
||||
channels=NUM_CHANNELS
|
||||
)
|
||||
|
||||
def _patch_method(self, event_name, *args, **kwargs):
|
||||
try:
|
||||
for handler in self._event_handlers[event_name]:
|
||||
@@ -106,6 +119,18 @@ class DailyTransport(ThreadedTransport, EventHandler):
|
||||
self._logger.error(f"Exception in event handler {event_name}: {e}")
|
||||
raise e
|
||||
|
||||
def _webrtc_vad_analyze(self):
|
||||
buffer = self.read_audio_frames(
|
||||
int(self._vad_samples))
|
||||
if len(buffer) > 0:
|
||||
confidence = self.webrtc_vad.analyze_frames(buffer)
|
||||
# yeses = int(confidence * 20.0)
|
||||
# nos = 20 - yeses
|
||||
# out = "!" * yeses + "." * nos
|
||||
# print(f"!!! confidence: {out} {confidence}")
|
||||
talking = confidence > SPEECH_THRESHOLD
|
||||
return talking
|
||||
|
||||
def add_event_handler(self, event_name: str, handler):
|
||||
if not event_name.startswith("on_"):
|
||||
raise Exception(
|
||||
|
||||
@@ -40,9 +40,6 @@ def int2float(sound):
|
||||
return sound
|
||||
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
|
||||
|
||||
class VADState(Enum):
|
||||
QUIET = 1
|
||||
STARTING = 2
|
||||
@@ -61,11 +58,12 @@ class ThreadedTransport(AbstractTransport):
|
||||
self._vad_stop_s = kwargs.get("vad_stop_s") or 0.8
|
||||
self._context = kwargs.get("context") or []
|
||||
self._vad_enabled = kwargs.get("vad_enabled") or False
|
||||
|
||||
self._has_webrtc_vad = kwargs.get("has_webrtc_vad") or False
|
||||
if self._vad_enabled and self._speaker_enabled:
|
||||
raise Exception(
|
||||
"Sorry, you can't use speaker_enabled and vad_enabled at the same time. Please set one to False."
|
||||
)
|
||||
self._vad_samples = 1536
|
||||
|
||||
if self._vad_enabled:
|
||||
try:
|
||||
@@ -79,14 +77,19 @@ class ThreadedTransport(AbstractTransport):
|
||||
(self.model, self.utils) = torch.hub.load(
|
||||
repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
|
||||
)
|
||||
self._logger.debug("Loaded Silero VAD")
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
print(f"Exception: {e}")
|
||||
print("In order to use VAD, you'll need to install the `torch` and `torchaudio` modules.")
|
||||
raise Exception(f"Missing module(s): {e}")
|
||||
if self._has_webrtc_vad:
|
||||
self._logger.debug(f"Couldn't load torch; using webrtc VAD")
|
||||
self._vad_samples = int(self._speaker_sample_rate / 100.0)
|
||||
else:
|
||||
self._logger.error(f"Exception: {e}")
|
||||
self._logger.error(
|
||||
"In order to use VAD, you'll need to install the `torch` and `torchaudio` modules.")
|
||||
raise Exception(f"Missing module(s): {e}")
|
||||
|
||||
self._vad_samples = 1536
|
||||
vad_frame_s = self._vad_samples / SAMPLE_RATE
|
||||
vad_frame_s = self._vad_samples / self._speaker_sample_rate
|
||||
self._vad_start_frames = round(self._vad_start_s / vad_frame_s)
|
||||
self._vad_stop_frames = round(self._vad_stop_s / vad_frame_s)
|
||||
self._vad_starting_count = 0
|
||||
@@ -262,19 +265,28 @@ class ThreadedTransport(AbstractTransport):
|
||||
def _prerun(self):
|
||||
pass
|
||||
|
||||
def _vad(self):
|
||||
# CB: Starting silero VAD stuff
|
||||
# TODO-CB: Probably need to force virtual speaker creation if we're
|
||||
# going to build this in?
|
||||
# TODO-CB: pyaudio installation
|
||||
while not self._stop_threads.is_set():
|
||||
audio_chunk = self.read_audio_frames(self._vad_samples)
|
||||
audio_int16 = np.frombuffer(audio_chunk, np.int16)
|
||||
audio_float32 = int2float(audio_int16)
|
||||
new_confidence = self.model(
|
||||
torch.from_numpy(audio_float32), 16000).item()
|
||||
speaking = new_confidence > 0.5
|
||||
def _silero_vad_analyze(self):
|
||||
audio_chunk = self.read_audio_frames(self._vad_samples)
|
||||
audio_int16 = np.frombuffer(audio_chunk, np.int16)
|
||||
audio_float32 = int2float(audio_int16)
|
||||
new_confidence = self.model(
|
||||
torch.from_numpy(audio_float32), 16000).item()
|
||||
# yeses = int(new_confidence * 20.0)
|
||||
# nos = 20 - yeses
|
||||
# out = "!" * yeses + "." * nos
|
||||
# print(f"!!! confidence: {out}")
|
||||
speaking = new_confidence > 0.5
|
||||
return speaking
|
||||
|
||||
def _vad(self):
|
||||
|
||||
while not self._stop_threads.is_set():
|
||||
if hasattr(self, 'model'): # we can use Silero
|
||||
speaking = self._silero_vad_analyze()
|
||||
elif self._has_webrtc_vad:
|
||||
speaking = self._webrtc_vad_analyze()
|
||||
else:
|
||||
raise Exception("VAD is running with no VAD service available")
|
||||
if speaking:
|
||||
match self._vad_state:
|
||||
case VADState.QUIET:
|
||||
@@ -311,6 +323,7 @@ class ThreadedTransport(AbstractTransport):
|
||||
self._vad_state == VADState.STOPPING
|
||||
and self._vad_stopping_count >= self._vad_stop_frames
|
||||
):
|
||||
|
||||
if self._loop:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.receive_queue.put(
|
||||
|
||||
Reference in New Issue
Block a user