Refactoring the BaseEndOfTurnAnalyzer to include most of the logic

This commit is contained in:
Filipi Fuchter
2025-04-16 06:11:56 -03:00
parent cfefcac35f
commit 3e2d21779f
2 changed files with 122 additions and 119 deletions

View File

@@ -5,9 +5,13 @@
#
import time
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
from typing import Dict, Optional
import numpy as np
from loguru import logger
class EndOfTurnState(Enum):
@@ -15,11 +19,23 @@ class EndOfTurnState(Enum):
INCOMPLETE = 2
# TODO: we should convert all this to params
STOP_MS = 1000
PRE_SPEECH_MS = 200
MAX_DURATION_SECONDS = 8 # Maximum duration for the smart turn model
class BaseEndOfTurnAnalyzer(ABC):
def __init__(self, *, sample_rate: Optional[int] = None):
self._init_sample_rate = sample_rate
# settings variables
self._sample_rate = 0
self._chunk_size_ms = 0
# inference variables
self._audio_buffer = []
self._speech_triggered = False
self._silence_frames = 0
self._speech_start_time = None
@property
def sample_rate(self) -> int:
@@ -35,10 +51,107 @@ class BaseEndOfTurnAnalyzer(ABC):
def set_chunk_size_ms(self, chunk_size_ms: int):
self._chunk_size_ms = chunk_size_ms
@abstractmethod
def append_audio(self, buffer: bytes, is_speech: bool) -> EndOfTurnState:
pass
audio_int16 = np.frombuffer(buffer, dtype=np.int16)
# Divide by 32768 because we have signed 16-bit data.
audio_float32 = np.frombuffer(audio_int16, dtype=np.int16).astype(np.float32) / 32768.0
state = EndOfTurnState.INCOMPLETE
if is_speech:
self._silence_frames = 0
self._speech_triggered = True
if self._speech_start_time is None:
self._speech_start_time = time.time()
self._audio_buffer.append((time.time(), audio_float32))
else:
if self._speech_triggered:
self._audio_buffer.append((time.time(), audio_float32))
self._silence_frames += 1
if self._silence_frames * self._chunk_size_ms >= STOP_MS:
logger.debug("End of Turn complete due to STOP_MS.")
state = EndOfTurnState.COMPLETE
self._clear()
else:
# Keep buffering some silence before potential speech starts
self._audio_buffer.append((time.time(), audio_float32))
# Keep the buffer size reasonable, assuming CHUNK is small
max_buffer_time = (
PRE_SPEECH_MS + STOP_MS
) / 1000 + MAX_DURATION_SECONDS # Some extra buffer
while (
self._audio_buffer and self._audio_buffer[0][0] < time.time() - max_buffer_time
):
self._audio_buffer.pop(0)
return state
def analyze_end_of_turn(self) -> EndOfTurnState:
logger.debug("Analyzing End of Turn...")
state = self._process_speech_segment(self._audio_buffer)
if state == EndOfTurnState.COMPLETE:
self._clear()
logger.debug(f"End of Turn result: {state}")
return state
def _clear(self):
self._speech_triggered = False
self._audio_buffer = []
self._speech_start_time = None
self._silence_frames = 0
def _process_speech_segment(self, audio_buffer) -> EndOfTurnState:
state = EndOfTurnState.INCOMPLETE
if not audio_buffer:
return state
# Find start and end indices for the segment
start_time = self._speech_start_time - (PRE_SPEECH_MS / 1000)
start_index = 0
for i, (t, _) in enumerate(audio_buffer):
if t >= start_time:
start_index = i
break
end_index = len(audio_buffer) - 1
# Extract the audio segment
segment_audio_chunks = [chunk for _, chunk in audio_buffer[start_index : end_index + 1]]
segment_audio = np.concatenate(segment_audio_chunks)
# Remove (STOP_MS - 200)ms from the end of the segment
samples_to_remove = int((STOP_MS - 200) / 1000 * self.sample_rate)
segment_audio = segment_audio[:-samples_to_remove]
# Limit maximum duration
if len(segment_audio) / self.sample_rate > MAX_DURATION_SECONDS:
segment_audio = segment_audio[: int(MAX_DURATION_SECONDS * self.sample_rate)]
# No resampling needed as both recording and prediction use 16000 Hz
segment_audio_resampled = segment_audio
if len(segment_audio_resampled) > 0:
# Call the new predict_endpoint function with the audio data
start_time = time.perf_counter()
result = self._predict_endpoint(segment_audio_resampled)
state = (
EndOfTurnState.COMPLETE if result["prediction"] == 1 else EndOfTurnState.INCOMPLETE
)
end_time = time.perf_counter()
logger.debug("--------")
logger.debug(f"Prediction: {'Complete' if result['prediction'] == 1 else 'Incomplete'}")
logger.debug(f"Probability of complete: {result['probability']:.4f}")
logger.debug(f"Prediction took {(end_time - start_time) * 1000:.2f}ms seconds")
else:
logger.debug("Captured empty audio segment, skipping prediction.")
return state
@abstractmethod
def analyze_end_of_turn(self) -> EndOfTurnState:
def _predict_endpoint(self, buffer: np.ndarray) -> Dict[str, any]:
pass

View File

@@ -6,13 +6,15 @@
import os
import time
from typing import Dict
import numpy as np
import torch
from loguru import logger
from pipecat.audio.turn.base_turn_analyzer import BaseEndOfTurnAnalyzer, EndOfTurnState
from pipecat.audio.turn.base_turn_analyzer import (
BaseEndOfTurnAnalyzer,
)
try:
import coremltools as ct
@@ -25,12 +27,6 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
# TODO: we should convert all this to params
STOP_MS = 1000
PRE_SPEECH_MS = 200
MAX_DURATION_SECONDS = 8 # Maximum duration for the smart turn model
class LocalSmartTurnAnalyzer(BaseEndOfTurnAnalyzer):
def __init__(self):
super().__init__()
@@ -63,113 +59,7 @@ class LocalSmartTurnAnalyzer(BaseEndOfTurnAnalyzer):
self._turn_model = ct.models.MLModel(core_ml_model_path)
logger.debug("Loaded Local Smart Turn")
self._audio_buffer = []
self._speech_triggered = False
self._silence_frames = 0
self._speech_start_time = None
def append_audio(self, buffer: bytes, is_speech: bool) -> EndOfTurnState:
audio_int16 = np.frombuffer(buffer, dtype=np.int16)
# Divide by 32768 because we have signed 16-bit data.
audio_float32 = np.frombuffer(audio_int16, dtype=np.int16).astype(np.float32) / 32768.0
state = EndOfTurnState.INCOMPLETE
if is_speech:
self._silence_frames = 0
self._speech_triggered = True
if self._speech_start_time is None:
self._speech_start_time = time.time()
self._audio_buffer.append((time.time(), audio_float32))
else:
if self._speech_triggered:
self._audio_buffer.append((time.time(), audio_float32))
self._silence_frames += 1
if self._silence_frames * self._chunk_size_ms >= STOP_MS:
logger.debug("End of Turn complete due to STOP_MS.")
state = EndOfTurnState.COMPLETE
self._clear()
else:
# Keep buffering some silence before potential speech starts
self._audio_buffer.append((time.time(), audio_float32))
# Keep the buffer size reasonable, assuming CHUNK is small
max_buffer_time = (
PRE_SPEECH_MS + STOP_MS
) / 1000 + MAX_DURATION_SECONDS # Some extra buffer
while (
self._audio_buffer and self._audio_buffer[0][0] < time.time() - max_buffer_time
):
self._audio_buffer.pop(0)
return state
def analyze_end_of_turn(self) -> EndOfTurnState:
logger.debug("Analyzing End of Turn...")
state = self._process_speech_segment(self._audio_buffer)
if state == EndOfTurnState.COMPLETE:
self._clear()
logger.debug(f"End of Turn result: {state}")
return state
def _clear(self):
self._speech_triggered = False
self._audio_buffer = []
self._speech_start_time = None
self._silence_frames = 0
def _process_speech_segment(self, audio_buffer) -> EndOfTurnState:
state = EndOfTurnState.INCOMPLETE
if not audio_buffer:
return state
# Find start and end indices for the segment
start_time = self._speech_start_time - (PRE_SPEECH_MS / 1000)
start_index = 0
for i, (t, _) in enumerate(audio_buffer):
if t >= start_time:
start_index = i
break
end_index = len(audio_buffer) - 1
# Extract the audio segment
segment_audio_chunks = [chunk for _, chunk in audio_buffer[start_index : end_index + 1]]
segment_audio = np.concatenate(segment_audio_chunks)
# Remove (STOP_MS - 200)ms from the end of the segment
samples_to_remove = int((STOP_MS - 200) / 1000 * self.sample_rate)
segment_audio = segment_audio[:-samples_to_remove]
# Limit maximum duration
if len(segment_audio) / self.sample_rate > MAX_DURATION_SECONDS:
segment_audio = segment_audio[: int(MAX_DURATION_SECONDS * self.sample_rate)]
# No resampling needed as both recording and prediction use 16000 Hz
segment_audio_resampled = segment_audio
if len(segment_audio_resampled) > 0:
# Call the new predict_endpoint function with the audio data
start_time = time.perf_counter()
result = self._predict_endpoint(segment_audio_resampled)
state = (
EndOfTurnState.COMPLETE if result["prediction"] == 1 else EndOfTurnState.INCOMPLETE
)
end_time = time.perf_counter()
logger.debug("--------")
logger.debug(f"Prediction: {'Complete' if result['prediction'] == 1 else 'Incomplete'}")
logger.debug(f"Probability of complete: {result['probability']:.4f}")
logger.debug(f"Prediction took {(end_time - start_time) * 1000:.2f}ms seconds")
else:
logger.debug("Captured empty audio segment, skipping prediction.")
return state
def _predict_endpoint(self, audio_array):
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, any]:
"""
Predict whether an audio segment is complete (turn ended) or incomplete.