Simplify VAD events to be detected and emitted from BaseInputTransport

This commit is contained in:
Mark Backman
2025-04-25 10:28:44 -04:00
parent b298376766
commit dfa10af6ed
2 changed files with 11 additions and 36 deletions

View File

@@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional, Tuple
from typing import Optional
from loguru import logger
from pydantic import BaseModel
@@ -88,24 +88,12 @@ class VADAnalyzer(ABC):
volume = calculate_audio_volume(audio, self.sample_rate)
return exp_smoothing(volume, self._prev_volume, self._smoothing_factor)
def analyze_audio(self, buffer) -> Tuple[VADState, Optional[str]]:
"""Analyze audio for voice activity.
Args:
buffer: Audio buffer to analyze
Returns:
Tuple containing:
- VADState: Current VAD state
- Optional[str]: Event type if a speech event occurred ("speech_started",
"speech_stopped"), or None if no event occurred
"""
def analyze_audio(self, buffer) -> VADState:
self._vad_buffer += buffer
event_type = None
num_required_bytes = self._vad_frames_num_bytes
if len(self._vad_buffer) < num_required_bytes:
return self._vad_state, event_type
return self._vad_state
audio_frames = self._vad_buffer[:num_required_bytes]
self._vad_buffer = self._vad_buffer[num_required_bytes:]
@@ -144,7 +132,6 @@ class VADAnalyzer(ABC):
):
self._vad_state = VADState.SPEAKING
self._vad_starting_count = 0
event_type = "speech_started"
if (
self._vad_state == VADState.STOPPING
@@ -152,6 +139,5 @@ class VADAnalyzer(ABC):
):
self._vad_state = VADState.QUIET
self._vad_stopping_count = 0
event_type = "speech_stopped"
return self._vad_state, event_type
return self._vad_state

View File

@@ -230,13 +230,9 @@ class BaseInputTransport(FrameProcessor):
async def _vad_analyze(self, audio_frame: InputAudioRawFrame) -> VADState:
state = VADState.QUIET
if self.vad_analyzer:
state, event_type = await self.get_event_loop().run_in_executor(
state = await self.get_event_loop().run_in_executor(
self._executor, self.vad_analyzer.analyze_audio, audio_frame.audio
)
if event_type:
await self._handle_vad_event(event_type)
return state
async def _handle_vad(self, audio_frame: InputAudioRawFrame, vad_state: VADState):
@@ -254,10 +250,13 @@ class BaseInputTransport(FrameProcessor):
self._params.turn_analyzer is None
or not self._params.turn_analyzer.speech_triggered
)
if can_create_user_frames:
if new_vad_state == VADState.SPEAKING:
if new_vad_state == VADState.SPEAKING:
await self.push_frame(VADUserStartedSpeakingFrame())
if can_create_user_frames:
frame = UserStartedSpeakingFrame()
elif new_vad_state == VADState.QUIET:
elif new_vad_state == VADState.QUIET:
await self.push_frame(VADUserStoppedSpeakingFrame())
if can_create_user_frames:
frame = UserStoppedSpeakingFrame()
if frame:
@@ -266,16 +265,6 @@ class BaseInputTransport(FrameProcessor):
vad_state = new_vad_state
return vad_state
async def _handle_vad_event(self, event_type: str):
"""Handle VAD speech events by creating and pushing appropriate frames."""
if event_type == "speech_started":
logger.debug("VAD detected definitive speech start")
await self.push_frame(VADUserStartedSpeakingFrame())
elif event_type == "speech_stopped":
logger.debug("VAD detected definitive speech stop")
await self.push_frame(VADUserStoppedSpeakingFrame())
async def _handle_end_of_turn(self):
if self.turn_analyzer:
state, prediction = await self.turn_analyzer.analyze_end_of_turn()