def int2float(sound):
    """Convert integer PCM samples to float32 scaled by 1/32768.

    Credited in this file to Alexander Veysov (silero-vad example code).

    Args:
        sound: NumPy array of integer samples (the VAD loop passes int16
            data obtained with ``np.frombuffer``).

    Returns:
        A float32 array holding ``sound / 32768`` with singleton dimensions
        squeezed out. An empty input yields an empty float32 array.
    """
    # Scale unconditionally. The earlier `np.abs(sound).max() > 0` guard was
    # wrong for int16 input: abs(-32768) overflows back to -32768, so a chunk
    # whose only non-zero samples were -32768 was returned unscaled. The guard
    # also raised ValueError on an empty chunk. Scaling an all-zero chunk is a
    # no-op, so dropping the guard changes nothing for valid input.
    sound = sound.astype('float32')
    sound *= 1 / 32768  # power of two, so the scaling is exact in float32
    return sound.squeeze()  # e.g. (1, n) -> (n,); depends on the use case
def _vad(self):
    """Run voice-activity detection on incoming audio until the transport stops.

    Used as the target of the daemon VAD thread that is started when
    ``self._vad_enabled`` is set. Each pass reads ``self._vad_samples``
    frames with ``read_audio_frames``, scores the chunk with the
    module-level silero ``model`` and hands the speaking / not-speaking
    verdict to ``_vad_advance``. Any frame that helper returns is put on
    ``self.receive_queue`` via ``asyncio.run_coroutine_threadsafe`` on
    ``self._loop``, since this code does not run on the event-loop thread.
    """
    # TODO-CB: Probably need to force virtual speaker creation if we're
    # going to build this in?
    # TODO-CB: pyaudio installation
    speech_threshold = 0.5  # model confidence above this counts as speech
    sample_width = np.dtype(np.int16).itemsize

    while not self._stop_threads.is_set():
        audio_chunk = self.read_audio_frames(self._vad_samples)

        # NOTE(review): assumes read_audio_frames returns a raw 16-bit PCM
        # byte buffer -- confirm. An empty or odd-length buffer would make
        # the conversion or the model raise and silently kill this thread,
        # so skip it and retry after a short pause instead.
        n_bytes = len(audio_chunk)
        if n_bytes == 0 or n_bytes % sample_width:
            time.sleep(0.01)
            continue

        audio_float32 = int2float(np.frombuffer(audio_chunk, np.int16))

        # Inference only, so skip autograd bookkeeping. SAMPLE_RATE is the
        # same constant __init__ uses to turn vad_start_s / vad_stop_s into
        # frame counts, which keeps the two in agreement.
        with torch.no_grad():
            confidence = model(
                torch.from_numpy(audio_float32), SAMPLE_RATE).item()

        frame = self._vad_advance(confidence > speech_threshold)
        if frame is not None:
            asyncio.run_coroutine_threadsafe(
                self.receive_queue.put(frame), self._loop
            )


def _vad_advance(self, speaking):
    """Advance the VAD state machine by one audio chunk.

    Args:
        speaking: True when the current chunk was classified as speech.

    Returns:
        A ``UserStartedSpeakingFrame`` once ``self._vad_start_frames``
        consecutive speech chunks have been seen, a
        ``UserStoppedSpeakingFrame`` once ``self._vad_stop_frames``
        consecutive non-speech chunks follow speech, otherwise ``None``.
    """
    state = self._vad_state

    if speaking:
        if state == VADState.QUIET:
            state = VADState.STARTING
            self._vad_starting_count = 1
        elif state == VADState.STARTING:
            self._vad_starting_count += 1
        elif state == VADState.STOPPING:
            # Speech resumed before the stop was confirmed.
            state = VADState.SPEAKING
            self._vad_stopping_count = 0
    else:
        if state == VADState.STARTING:
            # Too short to count as speech; fall back to quiet.
            state = VADState.QUIET
            self._vad_starting_count = 0
        elif state == VADState.SPEAKING:
            state = VADState.STOPPING
            self._vad_stopping_count = 1
        elif state == VADState.STOPPING:
            self._vad_stopping_count += 1

    frame = None
    if (state == VADState.STARTING
            and self._vad_starting_count >= self._vad_start_frames):
        frame = UserStartedSpeakingFrame()
        # self.interrupt()
        state = VADState.SPEAKING
        self._vad_starting_count = 0
    elif (state == VADState.STOPPING
            and self._vad_stopping_count >= self._vad_stop_frames):
        frame = UserStoppedSpeakingFrame()
        state = VADState.QUIET
        self._vad_stopping_count = 0

    self._vad_state = state
    return frame
a/src/dailyai/services/daily_transport_service.py +++ b/src/dailyai/services/daily_transport_service.py @@ -31,6 +31,7 @@ class DailyTransportService(BaseTransportService, EventHandler): _speaker_enabled: bool _speaker_sample_rate: int + _vad_enabled: bool # This is necessary to override EventHandler's __new__ method. def __new__(cls, *args, **kwargs): @@ -142,7 +143,7 @@ class DailyTransportService(BaseTransportService, EventHandler): "camera", width=self._camera_width, height=self._camera_height, color_format="RGB" ) - if self._speaker_enabled: + if self._speaker_enabled or self._vad_enabled: self._speaker: VirtualSpeakerDevice = Daily.create_speaker_device( "speaker", sample_rate=self._speaker_sample_rate, channels=1 ) diff --git a/src/examples/foundational/06-listen-and-respond.py b/src/examples/foundational/06-listen-and-respond.py index fa5e077cc..7cceb607d 100644 --- a/src/examples/foundational/06-listen-and-respond.py +++ b/src/examples/foundational/06-listen-and-respond.py @@ -3,8 +3,9 @@ import os from dailyai.services.daily_transport_service import DailyTransportService from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService +from dailyai.services.ai_services import FrameLogger from dailyai.queue_aggregators import LLMAssistantContextAggregator, LLMContextAggregator, LLMUserContextAggregator -from examples.foundational.support.runner import configure +from support.runner import configure async def main(room_url: str, token): @@ -16,7 +17,8 @@ async def main(room_url: str, token): start_transcription=True, mic_enabled=True, mic_sample_rate=16000, - camera_enabled=False + camera_enabled=False, + vad_enabled=True ) llm = AzureLLMService( @@ -26,7 +28,8 @@ async def main(room_url: str, token): tts = AzureTTSService( api_key=os.getenv("AZURE_SPEECH_API_KEY"), region=os.getenv("AZURE_SPEECH_REGION")) - + fl = FrameLogger("Inner") + fl2 = FrameLogger("Outer") @transport.event_handler("on_first_other_participant_joined") async def 
on_first_other_participant_joined(transport): await tts.say("Hi, I'm listening!", transport.send_queue) @@ -44,14 +47,20 @@ async def main(room_url: str, token): await tts.run_to_queue( transport.send_queue, tma_out.run( - llm.run( + fl2.run( + llm.run( tma_in.run( - transport.get_receive_frames() + fl.run( + transport.get_receive_frames() + ) + ) ) ) - ) - ) + ) + ) + + transport.transcription_settings["extra"]["endpointing"] = True transport.transcription_settings["extra"]["punctuate"] = True await asyncio.gather(transport.run(), handle_transcriptions())