Isolated changes to add VAD (#32)

* added VAD

* added separate 'vad enabled' property
This commit is contained in:
chadbailey59
2024-02-28 15:16:44 -06:00
committed by GitHub
parent f710aeae95
commit d90fdb1cae
5 changed files with 157 additions and 10 deletions

View File

@@ -13,10 +13,13 @@ dependencies = [
"fal",
"faster_whisper",
"google-cloud-texttospeech",
"numpy",
"openai",
"Pillow",
"pyht",
"python-dotenv",
"torch",
"pyaudio",
"typing-extensions"
]

View File

@@ -58,3 +58,9 @@ class LLMMessagesQueueFrame(QueueFrame):
class AppMessageQueueFrame(QueueFrame):
message: Any
participantId: str
class UserStartedSpeakingFrame(QueueFrame):
pass
class UserStoppedSpeakingFrame(QueueFrame):
pass

View File

@@ -2,10 +2,15 @@ from abc import abstractmethod
import asyncio
import itertools
import logging
import numpy as np
import pyaudio
import torch
import torchaudio
import queue
import threading
import time
from typing import AsyncGenerator
from enum import Enum
from dailyai.queue_frame import (
AudioQueueFrame,
@@ -14,8 +19,57 @@ from dailyai.queue_frame import (
QueueFrame,
SpriteQueueFrame,
StartStreamQueueFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame
)
torch.set_num_threads(1)
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=False)
(get_speech_timestamps,
save_audio,
read_audio,
VADIterator,
collect_chunks) = utils
# Taken from utils_vad.py
def validate(model,
inputs: torch.Tensor):
with torch.no_grad():
outs = model(inputs)
return outs
# Provided by Alexander Veysov
def int2float(sound):
abs_max = np.abs(sound).max()
sound = sound.astype('float32')
if abs_max > 0:
sound *= 1/32768
sound = sound.squeeze() # depends on the use case
return sound
FORMAT = pyaudio.paInt16
CHANNELS = 1
SAMPLE_RATE = 16000
CHUNK = int(SAMPLE_RATE / 10)
audio = pyaudio.PyAudio()
class VADState(Enum):
QUIET = 1
STARTING = 2
SPEAKING = 3
STOPPING = 4
class BaseTransportService():
@@ -31,7 +85,23 @@ class BaseTransportService():
self._speaker_enabled = kwargs.get("speaker_enabled") or False
self._speaker_sample_rate = kwargs.get("speaker_sample_rate") or 16000
self._fps = kwargs.get("fps") or 8
self._vad_start_s = kwargs.get("vad_start_s") or 0.2
self._vad_stop_s = kwargs.get("vad_stop_s") or 0.8
self._context = kwargs.get("context") or []
self._vad_enabled = kwargs.get("vad_enabled") or False
if self._vad_enabled and self._speaker_enabled:
raise Exception("Sorry, you can't use speaker_enabled and vad_enabled at the same time. Please set one to False.")
self._vad_samples = 1536
vad_frame_s = self._vad_samples / SAMPLE_RATE
self._vad_start_frames = round(self._vad_start_s / vad_frame_s)
self._vad_stop_frames = round(self._vad_stop_s / vad_frame_s)
self._vad_starting_count = 0
self._vad_stopping_count = 0
self._vad_state = VADState.QUIET
self._user_is_speaking = False
duration_minutes = kwargs.get("duration_minutes") or 10
self._expiration = time.time() + duration_minutes * 60
@@ -66,6 +136,10 @@ class BaseTransportService():
if self._speaker_enabled:
self._receive_audio_thread = threading.Thread(target=self._receive_audio, daemon=True)
self._receive_audio_thread.start()
if self._vad_enabled:
self._vad_thread = threading.Thread(target=self._vad, daemon=True)
self._vad_thread.start()
try:
while (
@@ -89,6 +163,10 @@ class BaseTransportService():
if self._speaker_enabled:
self._receive_audio_thread.join()
if self._vad_enabled:
self._vad_thread.join()
def _post_run(self):
# Note that this function must be idempotent! It can be called multiple times
@@ -121,7 +199,57 @@ class BaseTransportService():
@abstractmethod
def _prerun(self):
pass
def _vad(self):
# CB: Starting silero VAD stuff
# TODO-CB: Probably need to force virtual speaker creation if we're
# going to build this in?
# TODO-CB: pyaudio installation
while not self._stop_threads.is_set():
audio_chunk = self.read_audio_frames(self._vad_samples)
audio_int16 = np.frombuffer(audio_chunk, np.int16)
audio_float32 = int2float(audio_int16)
new_confidence = model(
torch.from_numpy(audio_float32), 16000).item()
speaking = new_confidence > 0.5
if speaking:
match self._vad_state:
case VADState.QUIET:
self._vad_state = VADState.STARTING
self._vad_starting_count = 1
case VADState.STARTING:
self._vad_starting_count += 1
case VADState.STOPPING:
self._vad_state = VADState.SPEAKING
self._vad_stopping_count = 0
else:
match self._vad_state:
case VADState.STARTING:
self._vad_state = VADState.QUIET
self._vad_starting_count = 0
case VADState.SPEAKING:
self._vad_state = VADState.STOPPING
self._vad_stopping_count = 1
case VADState.STOPPING:
self._vad_stopping_count += 1
if self._vad_state == VADState.STARTING and self._vad_starting_count >= self._vad_start_frames:
asyncio.run_coroutine_threadsafe(
self.receive_queue.put(
UserStartedSpeakingFrame()), self._loop
)
# self.interrupt()
self._vad_state = VADState.SPEAKING
self._vad_starting_count = 0
if self._vad_state == VADState.STOPPING and self._vad_stopping_count >= self._vad_stop_frames:
asyncio.run_coroutine_threadsafe(
self.receive_queue.put(
UserStoppedSpeakingFrame()), self._loop
)
self._vad_state = VADState.QUIET
self._vad_stopping_count = 0
async def _marshal_frames(self):
while True:
frame: QueueFrame | list = await self.send_queue.get()

View File

@@ -31,6 +31,7 @@ class DailyTransportService(BaseTransportService, EventHandler):
_speaker_enabled: bool
_speaker_sample_rate: int
_vad_enabled: bool
# This is necessary to override EventHandler's __new__ method.
def __new__(cls, *args, **kwargs):
@@ -142,7 +143,7 @@ class DailyTransportService(BaseTransportService, EventHandler):
"camera", width=self._camera_width, height=self._camera_height, color_format="RGB"
)
if self._speaker_enabled:
if self._speaker_enabled or self._vad_enabled:
self._speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
"speaker", sample_rate=self._speaker_sample_rate, channels=1
)

View File

@@ -3,8 +3,9 @@ import os
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService
from dailyai.services.ai_services import FrameLogger
from dailyai.queue_aggregators import LLMAssistantContextAggregator, LLMContextAggregator, LLMUserContextAggregator
from examples.foundational.support.runner import configure
from support.runner import configure
async def main(room_url: str, token):
@@ -16,7 +17,8 @@ async def main(room_url: str, token):
start_transcription=True,
mic_enabled=True,
mic_sample_rate=16000,
camera_enabled=False
camera_enabled=False,
vad_enabled=True
)
llm = AzureLLMService(
@@ -26,7 +28,8 @@ async def main(room_url: str, token):
tts = AzureTTSService(
api_key=os.getenv("AZURE_SPEECH_API_KEY"),
region=os.getenv("AZURE_SPEECH_REGION"))
fl = FrameLogger("Inner")
fl2 = FrameLogger("Outer")
@transport.event_handler("on_first_other_participant_joined")
async def on_first_other_participant_joined(transport):
await tts.say("Hi, I'm listening!", transport.send_queue)
@@ -44,14 +47,20 @@ async def main(room_url: str, token):
await tts.run_to_queue(
transport.send_queue,
tma_out.run(
llm.run(
fl2.run(
llm.run(
tma_in.run(
transport.get_receive_frames()
fl.run(
transport.get_receive_frames()
)
)
)
)
)
)
)
)
transport.transcription_settings["extra"]["endpointing"] = True
transport.transcription_settings["extra"]["punctuate"] = True
await asyncio.gather(transport.run(), handle_transcriptions())