Adding some comments to the code.

This commit is contained in:
Filipi Fuchter
2025-04-16 08:58:40 -03:00
parent 5fa47b7a5c
commit cd8bd7f487

View File

@@ -4,7 +4,6 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import time
from abc import ABC, abstractmethod
from enum import Enum
@@ -15,14 +14,16 @@ from loguru import logger
from pydantic import BaseModel
# Enum for end-of-turn detection states
class EndOfTurnState(Enum):
    """Result of an end-of-turn check on the buffered audio.

    COMPLETE: the turn is considered finished.
    INCOMPLETE: the turn is still in progress.
    """

    COMPLETE = 1
    INCOMPLETE = 2
# Default timing parameters
STOP_SECS = 1
PRE_SPEECH_MS = 0
MAX_DURATION_SECONDS = 8  # Max allowed segment duration for the smart turn model
class SmartTurnParams(BaseModel):
@@ -37,11 +38,11 @@ class BaseSmartTurn(ABC):
):
self._init_sample_rate = sample_rate
self._params = params
# settings variables
# Configuration
self._sample_rate = 0
self._chunk_size_ms = 0
self._stop_ms = self._params.stop_secs * 1000
# inference variables
self._stop_ms = self._params.stop_secs * 1000 # silence threshold in ms
# Inference state
self._audio_buffer = []
self._speech_triggered = False
self._silence_frames = 0
@@ -52,7 +53,7 @@ class BaseSmartTurn(ABC):
return self._sample_rate
def set_sample_rate(self, sample_rate: int):
    """Store the sample rate to use for subsequent audio processing.

    Args:
        sample_rate: Sample rate to record on the instance.
    """
    # NOTE(review): the constructor-supplied ``_init_sample_rate`` is not
    # consulted here, so it never overrides ``sample_rate`` — confirm that
    # dropping the override is intended.
    self._sample_rate = sample_rate
@property
def chunk_size_ms(self) -> int:
@@ -62,13 +63,15 @@ class BaseSmartTurn(ABC):
self._chunk_size_ms = chunk_size_ms
def append_audio(self, buffer: bytes, is_speech: bool) -> EndOfTurnState:
# Convert raw audio to float32 format and append to the buffer
audio_int16 = np.frombuffer(buffer, dtype=np.int16)
# Divide by 32768 because we have signed 16-bit data.
audio_float32 = np.frombuffer(audio_int16, dtype=np.int16).astype(np.float32) / 32768.0
self._audio_buffer.append((time.time(), audio_float32))
state = EndOfTurnState.INCOMPLETE
if is_speech:
# Reset silence tracking on speech
self._silence_frames = 0
self._speech_triggered = True
if self._speech_start_time is None:
@@ -77,15 +80,18 @@ class BaseSmartTurn(ABC):
else:
if self._speech_triggered:
self._silence_frames += 1
# If silence exceeds threshold, mark end of turn
if self._silence_frames * self._chunk_size_ms >= self._stop_ms:
logger.debug("End of Turn complete due to stop_secs.")
state = EndOfTurnState.COMPLETE
self._clear()
else:
# Keep the buffer size reasonable, assuming CHUNK is small
# Trim buffer to prevent unbounded growth before speech
max_buffer_time = (
self._params.pre_speech_ms + self._stop_ms
) / 1000 + self._params.max_duration_secs # Some extra buffer
(self._params.pre_speech_ms / 1000)
+ self._params.stop_secs
+ self._params.max_duration_secs
)
while (
self._audio_buffer and self._audio_buffer[0][0] < time.time() - max_buffer_time
):
@@ -98,11 +104,11 @@ class BaseSmartTurn(ABC):
state = self._process_speech_segment(self._audio_buffer)
if state == EndOfTurnState.COMPLETE:
self._clear()
logger.debug(f"End of Turn result: {state}")
return state
def _clear(self):
# Reset internal state for next turn
logger.debug("Clearing audio buffer...")
self._speech_triggered = False
self._audio_buffer = []
@@ -115,7 +121,7 @@ class BaseSmartTurn(ABC):
if not audio_buffer:
return state
# Find start and end indices for the segment
# Extract recent audio segment for prediction
start_time = self._speech_start_time - (self._params.pre_speech_ms / 1000)
start_index = 0
for i, (t, _) in enumerate(audio_buffer):
@@ -137,17 +143,12 @@ class BaseSmartTurn(ABC):
logger.debug(f"Segment audio chunks after limiting duration: {len(segment_audio)}")
# No resampling needed as both recording and prediction use 16000 Hz
if len(segment_audio) > 0:
# Call the new predict_endpoint function with the audio data
start_time = time.perf_counter()
result = self._predict_endpoint(segment_audio)
state = (
EndOfTurnState.COMPLETE if result["prediction"] == 1 else EndOfTurnState.INCOMPLETE
)
end_time = time.perf_counter()
logger.debug("--------")
@@ -163,14 +164,14 @@ class BaseSmartTurn(ABC):
@abstractmethod
def _predict_endpoint(self, buffer: np.ndarray) -> Dict[str, any]:
"""
Predict whether an audio segment is complete (turn ended) or incomplete.
Abstract method to predict if a turn has ended based on audio.
Args:
audio_array: Numpy array containing audio samples at 16kHz
buffer: Float32 numpy array of audio samples at 16kHz.
Returns:
Dictionary containing prediction results:
- prediction: 1 for complete, 0 for incomplete
- probability: Probability of completion class
Dictionary with:
- prediction: 1 if turn is complete, else 0
- probability: Confidence of the prediction
"""
pass