Integrate eou and vad
6  processors/__init__.py  Normal file
@@ -0,0 +1,6 @@
"""Audio Processors Package"""

from processors.eou import EouDetector
from processors.vad import SileroVAD, VADProcessor

__all__ = ["EouDetector", "SileroVAD", "VADProcessor"]
80  processors/eou.py  Normal file
@@ -0,0 +1,80 @@
"""End-of-Utterance Detection."""

import time
from typing import Optional


class EouDetector:
    """
    End-of-utterance detector. Fires EOU only after continuous silence for
    silence_threshold_ms. Short pauses between sentences do not trigger EOU
    because speech resets the silence timer (one EOU per turn).
    """

    def __init__(self, silence_threshold_ms: int = 1000, min_speech_duration_ms: int = 250):
        """
        Initialize EOU detector.

        Args:
            silence_threshold_ms: How long silence must last to trigger EOU (default 1000ms)
            min_speech_duration_ms: Minimum speech duration to consider valid (default 250ms)
        """
        self.threshold = silence_threshold_ms / 1000.0
        self.min_speech = min_speech_duration_ms / 1000.0
        self._silence_threshold_ms = silence_threshold_ms
        self._min_speech_duration_ms = min_speech_duration_ms

        # State
        self.is_speaking = False
        self.speech_start_time = 0.0
        self.silence_start_time: Optional[float] = None
        self.triggered = False

    def process(self, vad_status: str) -> bool:
        """
        Process VAD status and detect end of utterance.

        Input: "Speech" or "Silence" (from VAD).
        Output: True if EOU detected, False otherwise.

        Short breaks between phrases reset the silence clock when speech
        resumes, so only one EOU is emitted after the user truly stops.
        """
        now = time.time()

        if vad_status == "Speech":
            if not self.is_speaking:
                self.is_speaking = True
                self.speech_start_time = now
                self.triggered = False
            # Any speech resets the silence timer: short pause + more speech = one utterance
            self.silence_start_time = None
            return False

        if vad_status == "Silence":
            if not self.is_speaking:
                return False
            if self.silence_start_time is None:
                self.silence_start_time = now

            speech_duration = self.silence_start_time - self.speech_start_time
            if speech_duration < self.min_speech:
                # Too short to be real speech; discard as noise
                self.is_speaking = False
                self.silence_start_time = None
                return False

            silence_duration = now - self.silence_start_time
            if silence_duration >= self.threshold and not self.triggered:
                self.triggered = True
                self.is_speaking = False
                self.silence_start_time = None
                return True

        return False

    def reset(self) -> None:
        """Reset EOU detector state."""
        self.is_speaking = False
        self.speech_start_time = 0.0
        self.silence_start_time = None
        self.triggered = False
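
Taken on its own, EouDetector is a pure state machine over VAD labels, so it can be exercised without any audio. A minimal sketch of how it is meant to be driven; time.sleep stands in for the real 20ms chunk cadence, and the timings are chosen so only the final silence run crosses the 1000ms threshold:

import time

from processors.eou import EouDetector

detector = EouDetector(silence_threshold_ms=1000, min_speech_duration_ms=250)

def feed(label: str, seconds: float) -> None:
    """Feed one VAD label repeatedly for `seconds` at a 20ms cadence."""
    end = time.time() + seconds
    while time.time() < end:
        if detector.process(label):
            print("EOU detected")
        time.sleep(0.02)

feed("Speech", 0.5)    # first sentence
feed("Silence", 0.3)   # short pause between sentences: no EOU
feed("Speech", 0.5)    # second sentence (resets the silence clock)
feed("Silence", 1.2)   # real stop: EOU fires exactly once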
168  processors/tracks.py  Normal file
@@ -0,0 +1,168 @@
"""Audio track processing for WebRTC."""

import asyncio
import fractions

import numpy as np
from loguru import logger

# Try to import aiortc (optional for WebRTC functionality)
try:
    from aiortc import AudioStreamTrack
    AIORTC_AVAILABLE = True
except ImportError:
    AIORTC_AVAILABLE = False
    AudioStreamTrack = object  # Dummy class for type hints

# Try to import PyAV (optional for audio resampling)
try:
    from av import AudioFrame, AudioResampler
    AV_AVAILABLE = True
except ImportError:
    AV_AVAILABLE = False

    # Create dummy classes for type hints
    class AudioFrame:
        pass

    class AudioResampler:
        pass


class Resampled16kTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Audio track that resamples input to 16kHz mono PCM.

    Wraps an existing MediaStreamTrack and converts its output
    to 16kHz mono 16-bit PCM format for the pipeline.
    """

    def __init__(self, track, target_sample_rate: int = 16000):
        """
        Initialize resampled track.

        Args:
            track: Source MediaStreamTrack
            target_sample_rate: Target sample rate (default: 16000)
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - Resampled16kTrack cannot be used")

        super().__init__()
        self.track = track
        self.target_sample_rate = target_sample_rate

        if AV_AVAILABLE:
            self.resampler = AudioResampler(
                format="s16",
                layout="mono",
                rate=target_sample_rate
            )
        else:
            logger.warning("PyAV not available, audio resampling disabled")
            self.resampler = None

        self._closed = False

    async def recv(self):
        """
        Receive and resample next audio frame.

        Returns:
            Resampled AudioFrame at 16kHz mono
        """
        if self._closed:
            raise RuntimeError("Track is closed")

        # Get frame from source track
        frame = await self.track.recv()

        # Resample if PyAV is available. PyAV >= 9 returns a list of frames
        # (the resampler may buffer); older versions return a single frame.
        if AV_AVAILABLE and self.resampler:
            resampled = self.resampler.resample(frame)
            if isinstance(resampled, list):
                # Pull from the source until the resampler emits a frame
                while not resampled:
                    frame = await self.track.recv()
                    resampled = self.resampler.resample(frame)
                return resampled[0]
            return resampled

        # Return frame as-is if PyAV is not available
        return frame

    async def stop(self) -> None:
        """Stop the track and cleanup resources."""
        self._closed = True
        if hasattr(self, 'resampler') and self.resampler:
            del self.resampler
        logger.debug("Resampled track stopped")


class SineWaveTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Synthetic audio track that generates a sine wave.

    Useful for testing without requiring real audio input.
    """

    def __init__(self, sample_rate: int = 16000, frequency: int = 440):
        """
        Initialize sine wave track.

        Args:
            sample_rate: Audio sample rate (default: 16000)
            frequency: Sine wave frequency in Hz (default: 440)
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - SineWaveTrack cannot be used")

        super().__init__()
        self.sample_rate = sample_rate
        self.frequency = frequency
        self.counter = 0
        self._stopped = False

    async def recv(self):
        """
        Generate next audio frame with sine wave.

        Returns:
            AudioFrame with sine wave data
        """
        if self._stopped:
            raise RuntimeError("Track is stopped")

        # Pace delivery to roughly real time (one 20ms frame per recv).
        # aiortc expects tracks to deliver frames at the appropriate rate.
        await asyncio.sleep(0.02)

        # Generate 20ms of audio
        samples = int(self.sample_rate * 0.02)
        pts = self.counter
        time_base = fractions.Fraction(1, self.sample_rate)

        # Sample times for this frame, continuing from the previous frame
        t = np.linspace(
            self.counter / self.sample_rate,
            (self.counter + samples) / self.sample_rate,
            samples,
            endpoint=False
        )

        # Generate sine wave (Int16 PCM)
        data = (0.5 * np.sin(2 * np.pi * self.frequency * t) * 32767).astype(np.int16)

        # Update counter
        self.counter += samples

        # Create AudioFrame if AV is available
        if AV_AVAILABLE:
            frame = AudioFrame.from_ndarray(data.reshape(1, -1), format='s16', layout='mono')
            frame.pts = pts
            frame.time_base = time_base
            frame.sample_rate = self.sample_rate
            return frame
        else:
            # Return simple data structure if AV is not available
            return {
                'data': data,
                'sample_rate': self.sample_rate,
                'pts': pts,
                'time_base': time_base
            }

    def stop(self) -> None:
        """Stop the track."""
        self._stopped = True
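
Assuming aiortc and PyAV are both installed, the two tracks compose for a quick end-to-end check: feed a 48kHz sine source through Resampled16kTrack and confirm the output frames are 16kHz mono (960 input samples per 20ms frame become ~320). A sketch:

import asyncio

from processors.tracks import Resampled16kTrack, SineWaveTrack

async def main() -> None:
    source = SineWaveTrack(sample_rate=48000, frequency=440)
    track = Resampled16kTrack(source)            # target rate defaults to 16000
    for _ in range(5):
        frame = await track.recv()
        print(frame.sample_rate, frame.samples)  # expect 16000 and ~320 samples
    await track.stop()
    source.stop()

asyncio.run(main())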
213  processors/vad.py  Normal file
@@ -0,0 +1,213 @@
"""Voice Activity Detection using Silero VAD."""

import os
import time
from typing import Tuple, Optional

import numpy as np
from loguru import logger

from processors.eou import EouDetector

# Try to import onnxruntime (optional for VAD functionality)
try:
    import onnxruntime as ort
    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False
    ort = None
    logger.warning("onnxruntime not available - VAD will be disabled")


class SileroVAD:
    """
    Voice Activity Detection using Silero VAD model.

    Detects speech in audio chunks using the Silero VAD ONNX model.
    Returns "Speech" or "Silence" for each audio chunk.
    """

    def __init__(self, model_path: str = "data/vad/silero_vad.onnx", sample_rate: int = 16000):
        """
        Initialize Silero VAD.

        Args:
            model_path: Path to Silero VAD ONNX model
            sample_rate: Audio sample rate (must be 16kHz for Silero VAD)
        """
        self.sample_rate = sample_rate
        self.model_path = model_path
        self.session = None

        # Initialize internal state up front so reset() and process_audio()
        # stay safe even when the model cannot be loaded and we return early.
        self._reset_state()
        self.buffer = np.array([], dtype=np.float32)
        self.min_chunk_size = 512  # Silero VAD expects 512-sample chunks at 16kHz
        self.last_label = "Silence"
        self.last_probability = 0.0

        # Check if model exists
        if not os.path.exists(model_path):
            logger.warning(f"VAD model not found at {model_path}. VAD will be disabled.")
            return

        # Check if onnxruntime is available
        if not ONNX_AVAILABLE:
            logger.warning("onnxruntime not available - VAD will be disabled")
            return

        # Load ONNX model
        try:
            self.session = ort.InferenceSession(model_path)
            logger.info(f"Loaded Silero VAD model from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load VAD model: {e}")
            self.session = None

    def _reset_state(self):
        # Silero VAD V4+ expects state shape [2, 1, 128]
        self._state = np.zeros((2, 1, 128), dtype=np.float32)
        self._sr = np.array([self.sample_rate], dtype=np.int64)

    def process_audio(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Tuple[str, float]:
        """
        Process audio chunk and detect speech.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
            chunk_size_ms: Chunk duration in milliseconds (ignored for buffering logic)

        Returns:
            Tuple of (label, probability) where label is "Speech" or "Silence"
        """
        if self.session is None or not ONNX_AVAILABLE:
            # If model not loaded or onnxruntime not available, assume speech
            return "Speech", 1.0

        # Convert bytes to numpy array of int16
        audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)

        # Normalize to float32 (-1.0 to 1.0)
        audio_float = audio_int16.astype(np.float32) / 32768.0

        # Add to buffer
        self.buffer = np.concatenate((self.buffer, audio_float))

        # Process all complete 512-sample chunks in the buffer
        while len(self.buffer) >= self.min_chunk_size:
            # Slice exactly 512 samples
            chunk = self.buffer[:self.min_chunk_size]
            self.buffer = self.buffer[self.min_chunk_size:]

            # Input tensor shape: [batch, samples] -> [1, 512]
            input_tensor = chunk.reshape(1, -1)

            # Run inference
            try:
                ort_inputs = {
                    'input': input_tensor,
                    'state': self._state,
                    'sr': self._sr
                }

                # Outputs: probability, state
                out, self._state = self.session.run(None, ort_inputs)

                # Get probability
                self.last_probability = float(out[0][0])
                self.last_label = "Speech" if self.last_probability >= 0.5 else "Silence"

            except Exception as e:
                logger.error(f"VAD inference error: {e}")
                # Log the expected input names to help diagnose model mismatches
                try:
                    inputs = [x.name for x in self.session.get_inputs()]
                    logger.error(f"Model expects inputs: {inputs}")
                except Exception:
                    pass
                return "Speech", 1.0

        return self.last_label, self.last_probability

    def reset(self) -> None:
        """Reset VAD internal state."""
        self._reset_state()
        self.buffer = np.array([], dtype=np.float32)
        self.last_label = "Silence"
        self.last_probability = 0.0


class VADProcessor:
    """
    High-level VAD processor with state management.

    Tracks speech/silence state and emits events on transitions.
    """

    def __init__(self, vad_model: SileroVAD, threshold: float = 0.5,
                 silence_threshold_ms: int = 1000, min_speech_duration_ms: int = 250):
        """
        Initialize VAD processor.

        Args:
            vad_model: Silero VAD model instance
            threshold: Speech detection threshold
            silence_threshold_ms: EOU silence threshold in ms (longer = one EOU across short pauses)
            min_speech_duration_ms: EOU min speech duration in ms (ignore very short noises)
        """
        self.vad = vad_model
        self.threshold = threshold
        self._eou_silence_ms = silence_threshold_ms
        self._eou_min_speech_ms = min_speech_duration_ms
        self.is_speaking = False
        self.speech_start_time: Optional[float] = None
        self.silence_start_time: Optional[float] = None
        self.eou_detector = EouDetector(silence_threshold_ms, min_speech_duration_ms)

    def process(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Optional[Tuple[str, float]]:
        """
        Process audio chunk and detect state changes.

        Args:
            pcm_bytes: PCM audio data
            chunk_size_ms: Chunk duration in milliseconds

        Returns:
            Tuple of (event_type, probability) if state changed, None otherwise.
            event_type is "speaking", "silence", or "eou".
        """
        label, probability = self.vad.process_audio(pcm_bytes, chunk_size_ms)

        # Check if this is speech based on threshold
        is_speech = probability >= self.threshold

        # Check EOU
        if self.eou_detector.process("Speech" if is_speech else "Silence"):
            return ("eou", probability)

        # State transition: Silence -> Speech
        if is_speech and not self.is_speaking:
            self.is_speaking = True
            # Monotonic clock: safe whether or not an event loop is running
            self.speech_start_time = time.monotonic()
            self.silence_start_time = None
            return ("speaking", probability)

        # State transition: Speech -> Silence
        elif not is_speech and self.is_speaking:
            self.is_speaking = False
            self.silence_start_time = time.monotonic()
            self.speech_start_time = None
            return ("silence", probability)

        return None

    def reset(self) -> None:
        """Reset VAD state."""
        self.vad.reset()
        self.is_speaking = False
        self.speech_start_time = None
        self.silence_start_time = None
        self.eou_detector = EouDetector(self._eou_silence_ms, self._eou_min_speech_ms)
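
Putting the pieces together, the intended consumption pattern is a 20ms chunk loop that reacts to the three event types. A sketch; read_pcm_chunks() is a hypothetical source yielding 640-byte buffers (20ms at 16kHz, 16-bit mono):

from processors.vad import SileroVAD, VADProcessor

vad = SileroVAD(model_path="data/vad/silero_vad.onnx")
processor = VADProcessor(vad, threshold=0.5, silence_threshold_ms=1000)

# read_pcm_chunks() is hypothetical: yields 640-byte chunks (20ms @ 16kHz mono)
for chunk in read_pcm_chunks():
    event = processor.process(chunk)
    if event is None:
        continue  # no state change on this chunk
    kind, probability = event
    if kind == "speaking":
        print(f"user started speaking (p={probability:.2f})")
    elif kind == "silence":
        print(f"user paused (p={probability:.2f})")
    elif kind == "eou":
        print("end of utterance: the turn is complete")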