Adding some comments to the code.
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
# SPDX-License-Identifier: BSD-2-Clause
|
||||
#
|
||||
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
@@ -15,14 +14,16 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Enum for end-of-turn detection states
|
||||
class EndOfTurnState(Enum):
|
||||
COMPLETE = 1
|
||||
INCOMPLETE = 2
|
||||
|
||||
|
||||
# Default timing parameters
|
||||
STOP_SECS = 1
|
||||
PRE_SPEECH_MS = 0
|
||||
MAX_DURATION_SECONDS = 8 # Maximum duration for the smart turn model
|
||||
MAX_DURATION_SECONDS = 8 # Max allowed segment duration
|
||||
|
||||
|
||||
class SmartTurnParams(BaseModel):
|
||||
@@ -37,11 +38,11 @@ class BaseSmartTurn(ABC):
|
||||
):
|
||||
self._init_sample_rate = sample_rate
|
||||
self._params = params
|
||||
# settings variables
|
||||
# Configuration
|
||||
self._sample_rate = 0
|
||||
self._chunk_size_ms = 0
|
||||
self._stop_ms = self._params.stop_secs * 1000
|
||||
# inference variables
|
||||
self._stop_ms = self._params.stop_secs * 1000 # silence threshold in ms
|
||||
# Inference state
|
||||
self._audio_buffer = []
|
||||
self._speech_triggered = False
|
||||
self._silence_frames = 0
|
||||
@@ -52,7 +53,7 @@ class BaseSmartTurn(ABC):
|
||||
return self._sample_rate
|
||||
|
||||
def set_sample_rate(self, sample_rate: int):
|
||||
self._sample_rate = self._init_sample_rate or sample_rate
|
||||
self._sample_rate = sample_rate
|
||||
|
||||
@property
|
||||
def chunk_size_ms(self) -> int:
|
||||
@@ -62,13 +63,15 @@ class BaseSmartTurn(ABC):
|
||||
self._chunk_size_ms = chunk_size_ms
|
||||
|
||||
def append_audio(self, buffer: bytes, is_speech: bool) -> EndOfTurnState:
|
||||
# Convert raw audio to float32 format and append to the buffer
|
||||
audio_int16 = np.frombuffer(buffer, dtype=np.int16)
|
||||
# Divide by 32768 because we have signed 16-bit data.
|
||||
audio_float32 = np.frombuffer(audio_int16, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
self._audio_buffer.append((time.time(), audio_float32))
|
||||
|
||||
state = EndOfTurnState.INCOMPLETE
|
||||
|
||||
if is_speech:
|
||||
# Reset silence tracking on speech
|
||||
self._silence_frames = 0
|
||||
self._speech_triggered = True
|
||||
if self._speech_start_time is None:
|
||||
@@ -77,15 +80,18 @@ class BaseSmartTurn(ABC):
|
||||
else:
|
||||
if self._speech_triggered:
|
||||
self._silence_frames += 1
|
||||
# If silence exceeds threshold, mark end of turn
|
||||
if self._silence_frames * self._chunk_size_ms >= self._stop_ms:
|
||||
logger.debug("End of Turn complete due to stop_secs.")
|
||||
state = EndOfTurnState.COMPLETE
|
||||
self._clear()
|
||||
else:
|
||||
# Keep the buffer size reasonable, assuming CHUNK is small
|
||||
# Trim buffer to prevent unbounded growth before speech
|
||||
max_buffer_time = (
|
||||
self._params.pre_speech_ms + self._stop_ms
|
||||
) / 1000 + self._params.max_duration_secs # Some extra buffer
|
||||
(self._params.pre_speech_ms / 1000)
|
||||
+ self._params.stop_secs
|
||||
+ self._params.max_duration_secs
|
||||
)
|
||||
while (
|
||||
self._audio_buffer and self._audio_buffer[0][0] < time.time() - max_buffer_time
|
||||
):
|
||||
@@ -98,11 +104,11 @@ class BaseSmartTurn(ABC):
|
||||
state = self._process_speech_segment(self._audio_buffer)
|
||||
if state == EndOfTurnState.COMPLETE:
|
||||
self._clear()
|
||||
|
||||
logger.debug(f"End of Turn result: {state}")
|
||||
return state
|
||||
|
||||
def _clear(self):
|
||||
# Reset internal state for next turn
|
||||
logger.debug("Clearing audio buffer...")
|
||||
self._speech_triggered = False
|
||||
self._audio_buffer = []
|
||||
@@ -115,7 +121,7 @@ class BaseSmartTurn(ABC):
|
||||
if not audio_buffer:
|
||||
return state
|
||||
|
||||
# Find start and end indices for the segment
|
||||
# Extract recent audio segment for prediction
|
||||
start_time = self._speech_start_time - (self._params.pre_speech_ms / 1000)
|
||||
start_index = 0
|
||||
for i, (t, _) in enumerate(audio_buffer):
|
||||
@@ -137,17 +143,12 @@ class BaseSmartTurn(ABC):
|
||||
|
||||
logger.debug(f"Segment audio chunks after limiting duration: {len(segment_audio)}")
|
||||
|
||||
# No resampling needed as both recording and prediction use 16000 Hz
|
||||
if len(segment_audio) > 0:
|
||||
# Call the new predict_endpoint function with the audio data
|
||||
start_time = time.perf_counter()
|
||||
|
||||
result = self._predict_endpoint(segment_audio)
|
||||
|
||||
state = (
|
||||
EndOfTurnState.COMPLETE if result["prediction"] == 1 else EndOfTurnState.INCOMPLETE
|
||||
)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
|
||||
logger.debug("--------")
|
||||
@@ -163,14 +164,14 @@ class BaseSmartTurn(ABC):
|
||||
@abstractmethod
|
||||
def _predict_endpoint(self, buffer: np.ndarray) -> Dict[str, any]:
|
||||
"""
|
||||
Predict whether an audio segment is complete (turn ended) or incomplete.
|
||||
Abstract method to predict if a turn has ended based on audio.
|
||||
|
||||
Args:
|
||||
audio_array: Numpy array containing audio samples at 16kHz
|
||||
buffer: Float32 numpy array of audio samples at 16kHz.
|
||||
|
||||
Returns:
|
||||
Dictionary containing prediction results:
|
||||
- prediction: 1 for complete, 0 for incomplete
|
||||
- probability: Probability of completion class
|
||||
Dictionary with:
|
||||
- prediction: 1 if turn is complete, else 0
|
||||
- probability: Confidence of the prediction
|
||||
"""
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user