Isolated changes to add VAD (#32)
* added VAD * added separate 'vad enabled' property
This commit is contained in:
@@ -13,10 +13,13 @@ dependencies = [
|
||||
"fal",
|
||||
"faster_whisper",
|
||||
"google-cloud-texttospeech",
|
||||
"numpy",
|
||||
"openai",
|
||||
"Pillow",
|
||||
"pyht",
|
||||
"python-dotenv",
|
||||
"torch",
|
||||
"pyaudio",
|
||||
"typing-extensions"
|
||||
]
|
||||
|
||||
|
||||
@@ -58,3 +58,9 @@ class LLMMessagesQueueFrame(QueueFrame):
|
||||
class AppMessageQueueFrame(QueueFrame):
|
||||
message: Any
|
||||
participantId: str
|
||||
|
||||
class UserStartedSpeakingFrame(QueueFrame):
    """Payload-free marker frame signalling that the user has begun speaking."""
|
||||
|
||||
class UserStoppedSpeakingFrame(QueueFrame):
    """Payload-free marker frame signalling that the user has stopped speaking."""
|
||||
@@ -2,10 +2,15 @@ from abc import abstractmethod
|
||||
import asyncio
|
||||
import itertools
|
||||
import logging
|
||||
import numpy as np
|
||||
import pyaudio
|
||||
import torch
|
||||
import torchaudio
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import AsyncGenerator
|
||||
from enum import Enum
|
||||
|
||||
from dailyai.queue_frame import (
|
||||
AudioQueueFrame,
|
||||
@@ -14,8 +19,57 @@ from dailyai.queue_frame import (
|
||||
QueueFrame,
|
||||
SpriteQueueFrame,
|
||||
StartStreamQueueFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame
|
||||
)
|
||||
|
||||
# Limit torch to one intra-op thread before the VAD model is loaded.
torch.set_num_threads(1)

# Load the Silero VAD model and its helper utilities through torch.hub.
# NOTE(review): this runs at import time and may hit the network on first
# use -- confirm that is acceptable for every importer of this module.
model, utils = torch.hub.load(
    repo_or_dir='snakers4/silero-vad',
    model='silero_vad',
    force_reload=False,
)

# The utilities arrive as a tuple; unpack them in the order the hub entry
# point provides.
get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks = utils
|
||||
|
||||
# Taken from utils_vad.py
|
||||
|
||||
|
||||
def validate(model, inputs: torch.Tensor):
    """Call ``model`` on ``inputs`` with gradient tracking disabled.

    Args:
        model: Any callable accepting ``inputs``.
        inputs: Tensor handed to ``model`` unchanged.

    Returns:
        Whatever ``model(inputs)`` returns.
    """
    with torch.no_grad():
        return model(inputs)
|
||||
|
||||
# Provided by Alexander Veysov
|
||||
|
||||
|
||||
def int2float(sound):
    """Convert integer PCM samples to float32 scaled by 1/32768.

    The scale factor is applied unconditionally. The previous
    ``np.abs(sound).max() > 0`` guard was both unnecessary (scaling zeros is
    a no-op) and wrong: ``np.abs`` on int16 wraps -32768 back to -32768, so
    a buffer holding only 0 and -32768 was returned unscaled, and an empty
    buffer raised ``ValueError`` from the reduction.

    Args:
        sound: NumPy array of integer samples (int16 is what the 1/32768
            scale corresponds to).

    Returns:
        A float32 array with length-1 axes squeezed out.
    """
    scaled = sound.astype('float32')
    scaled *= 1 / 32768
    # Squeezing drops singleton axes; whether that is wanted depends on the
    # caller's use case.
    return scaled.squeeze()
|
||||
|
||||
|
||||
# Audio parameters: 16-bit signed samples, one channel, 16 kHz.
FORMAT = pyaudio.paInt16
CHANNELS = 1
SAMPLE_RATE = 16000
# A chunk is one tenth of a second of samples.
CHUNK = SAMPLE_RATE // 10

# Module-wide PyAudio instance, created at import time.
audio = pyaudio.PyAudio()
|
||||
|
||||
|
||||
class VADState(Enum):
    """Phases of the voice-activity state machine stepped by ``_vad``."""

    # No speech is being tracked.
    QUIET = 1
    # Speech was detected; passes are counted until the start threshold.
    STARTING = 2
    # The start threshold was reached; the user is considered to be speaking.
    SPEAKING = 3
    # Speech ceased; passes are counted until the stop threshold.
    STOPPING = 4
|
||||
|
||||
|
||||
class BaseTransportService():
|
||||
|
||||
@@ -31,7 +85,23 @@ class BaseTransportService():
|
||||
self._speaker_enabled = kwargs.get("speaker_enabled") or False
|
||||
self._speaker_sample_rate = kwargs.get("speaker_sample_rate") or 16000
|
||||
self._fps = kwargs.get("fps") or 8
|
||||
|
||||
self._vad_start_s = kwargs.get("vad_start_s") or 0.2
|
||||
self._vad_stop_s = kwargs.get("vad_stop_s") or 0.8
|
||||
self._context = kwargs.get("context") or []
|
||||
self._vad_enabled = kwargs.get("vad_enabled") or False
|
||||
|
||||
if self._vad_enabled and self._speaker_enabled:
|
||||
raise Exception("Sorry, you can't use speaker_enabled and vad_enabled at the same time. Please set one to False.")
|
||||
|
||||
self._vad_samples = 1536
|
||||
vad_frame_s = self._vad_samples / SAMPLE_RATE
|
||||
self._vad_start_frames = round(self._vad_start_s / vad_frame_s)
|
||||
self._vad_stop_frames = round(self._vad_stop_s / vad_frame_s)
|
||||
self._vad_starting_count = 0
|
||||
self._vad_stopping_count = 0
|
||||
self._vad_state = VADState.QUIET
|
||||
self._user_is_speaking = False
|
||||
|
||||
duration_minutes = kwargs.get("duration_minutes") or 10
|
||||
self._expiration = time.time() + duration_minutes * 60
|
||||
|
||||
@@ -66,6 +136,10 @@ class BaseTransportService():
|
||||
if self._speaker_enabled:
|
||||
self._receive_audio_thread = threading.Thread(target=self._receive_audio, daemon=True)
|
||||
self._receive_audio_thread.start()
|
||||
|
||||
if self._vad_enabled:
|
||||
self._vad_thread = threading.Thread(target=self._vad, daemon=True)
|
||||
self._vad_thread.start()
|
||||
|
||||
try:
|
||||
while (
|
||||
@@ -89,6 +163,10 @@ class BaseTransportService():
|
||||
|
||||
if self._speaker_enabled:
|
||||
self._receive_audio_thread.join()
|
||||
|
||||
if self._vad_enabled:
|
||||
self._vad_thread.join()
|
||||
|
||||
|
||||
def _post_run(self):
|
||||
# Note that this function must be idempotent! It can be called multiple times
|
||||
@@ -121,7 +199,57 @@ class BaseTransportService():
|
||||
@abstractmethod
|
||||
def _prerun(self):
|
||||
pass
|
||||
|
||||
|
||||
def _vad(self):
    """Voice-activity-detection worker loop.

    Runs until ``self._stop_threads`` is set. Each pass reads
    ``self._vad_samples`` frames through ``self.read_audio_frames``,
    interprets the bytes as int16 samples, scores them with the module-level
    Silero ``model`` and treats a score above 0.5 as speech. That result
    steps the ``VADState`` machine. A ``UserStartedSpeakingFrame`` is put on
    ``self.receive_queue`` after ``self._vad_start_frames`` consecutive
    speech passes, and a ``UserStoppedSpeakingFrame`` after
    ``self._vad_stop_frames`` consecutive non-speech passes.
    """
    # TODO-CB: Probably need to force virtual speaker creation if we're
    # going to build this in?
    # TODO-CB: pyaudio installation

    def post(frame):
        # The queue belongs to the event loop, so the put is scheduled on
        # ``self._loop`` rather than awaited from this worker.
        asyncio.run_coroutine_threadsafe(
            self.receive_queue.put(frame), self._loop
        )

    while not self._stop_threads.is_set():
        raw_bytes = self.read_audio_frames(self._vad_samples)
        samples = int2float(np.frombuffer(raw_bytes, np.int16))
        confidence = model(torch.from_numpy(samples), 16000).item()
        heard_speech = confidence > 0.5

        # Snapshot the state so exactly one transition applies per pass.
        state = self._vad_state
        if heard_speech:
            if state == VADState.QUIET:
                self._vad_state = VADState.STARTING
                self._vad_starting_count = 1
            elif state == VADState.STARTING:
                self._vad_starting_count += 1
            elif state == VADState.STOPPING:
                # Speech resumed before the stop threshold was reached.
                self._vad_state = VADState.SPEAKING
                self._vad_stopping_count = 0
        elif state == VADState.STARTING:
            # Speech ended before the start threshold was reached.
            self._vad_state = VADState.QUIET
            self._vad_starting_count = 0
        elif state == VADState.SPEAKING:
            self._vad_state = VADState.STOPPING
            self._vad_stopping_count = 1
        elif state == VADState.STOPPING:
            self._vad_stopping_count += 1

        # Promote a pending start/stop once enough passes have accumulated.
        if (self._vad_state == VADState.STARTING
                and self._vad_starting_count >= self._vad_start_frames):
            post(UserStartedSpeakingFrame())
            # self.interrupt()
            self._vad_state = VADState.SPEAKING
            self._vad_starting_count = 0
        elif (self._vad_state == VADState.STOPPING
                and self._vad_stopping_count >= self._vad_stop_frames):
            post(UserStoppedSpeakingFrame())
            self._vad_state = VADState.QUIET
            self._vad_stopping_count = 0
|
||||
|
||||
async def _marshal_frames(self):
|
||||
while True:
|
||||
frame: QueueFrame | list = await self.send_queue.get()
|
||||
|
||||
@@ -31,6 +31,7 @@ class DailyTransportService(BaseTransportService, EventHandler):
|
||||
|
||||
_speaker_enabled: bool
|
||||
_speaker_sample_rate: int
|
||||
_vad_enabled: bool
|
||||
|
||||
# This is necessary to override EventHandler's __new__ method.
|
||||
def __new__(cls, *args, **kwargs):
|
||||
@@ -142,7 +143,7 @@ class DailyTransportService(BaseTransportService, EventHandler):
|
||||
"camera", width=self._camera_width, height=self._camera_height, color_format="RGB"
|
||||
)
|
||||
|
||||
if self._speaker_enabled:
|
||||
if self._speaker_enabled or self._vad_enabled:
|
||||
self._speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
|
||||
"speaker", sample_rate=self._speaker_sample_rate, channels=1
|
||||
)
|
||||
|
||||
@@ -3,8 +3,9 @@ import os
|
||||
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService
|
||||
from dailyai.services.ai_services import FrameLogger
|
||||
from dailyai.queue_aggregators import LLMAssistantContextAggregator, LLMContextAggregator, LLMUserContextAggregator
|
||||
from examples.foundational.support.runner import configure
|
||||
from support.runner import configure
|
||||
|
||||
|
||||
async def main(room_url: str, token):
|
||||
@@ -16,7 +17,8 @@ async def main(room_url: str, token):
|
||||
start_transcription=True,
|
||||
mic_enabled=True,
|
||||
mic_sample_rate=16000,
|
||||
camera_enabled=False
|
||||
camera_enabled=False,
|
||||
vad_enabled=True
|
||||
)
|
||||
|
||||
llm = AzureLLMService(
|
||||
@@ -26,7 +28,8 @@ async def main(room_url: str, token):
|
||||
tts = AzureTTSService(
|
||||
api_key=os.getenv("AZURE_SPEECH_API_KEY"),
|
||||
region=os.getenv("AZURE_SPEECH_REGION"))
|
||||
|
||||
fl = FrameLogger("Inner")
|
||||
fl2 = FrameLogger("Outer")
|
||||
@transport.event_handler("on_first_other_participant_joined")
|
||||
async def on_first_other_participant_joined(transport):
|
||||
await tts.say("Hi, I'm listening!", transport.send_queue)
|
||||
@@ -44,14 +47,20 @@ async def main(room_url: str, token):
|
||||
await tts.run_to_queue(
|
||||
transport.send_queue,
|
||||
tma_out.run(
|
||||
llm.run(
|
||||
fl2.run(
|
||||
llm.run(
|
||||
tma_in.run(
|
||||
transport.get_receive_frames()
|
||||
fl.run(
|
||||
transport.get_receive_frames()
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
)
|
||||
)
|
||||
|
||||
transport.transcription_settings["extra"]["endpointing"] = True
|
||||
transport.transcription_settings["extra"]["punctuate"] = True
|
||||
await asyncio.gather(transport.run(), handle_transcriptions())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user