diff --git a/src/pipecat/audio/turn/base_turn_analyzer.py b/src/pipecat/audio/turn/base_turn_analyzer.py index 76777bc65..bb10c7df3 100644 --- a/src/pipecat/audio/turn/base_turn_analyzer.py +++ b/src/pipecat/audio/turn/base_turn_analyzer.py @@ -5,9 +5,13 @@ # +import time from abc import ABC, abstractmethod from enum import Enum -from typing import Optional +from typing import Dict, Optional + +import numpy as np +from loguru import logger class EndOfTurnState(Enum): @@ -15,11 +19,23 @@ class EndOfTurnState(Enum): INCOMPLETE = 2 +# TODO: we should convert all this to params +STOP_MS = 1000 +PRE_SPEECH_MS = 200 +MAX_DURATION_SECONDS = 8 # Maximum duration for the smart turn model + + class BaseEndOfTurnAnalyzer(ABC): def __init__(self, *, sample_rate: Optional[int] = None): self._init_sample_rate = sample_rate + # settings variables self._sample_rate = 0 self._chunk_size_ms = 0 + # inference variables + self._audio_buffer = [] + self._speech_triggered = False + self._silence_frames = 0 + self._speech_start_time = None @property def sample_rate(self) -> int: @@ -35,10 +51,107 @@ class BaseEndOfTurnAnalyzer(ABC): def set_chunk_size_ms(self, chunk_size_ms: int): self._chunk_size_ms = chunk_size_ms - @abstractmethod def append_audio(self, buffer: bytes, is_speech: bool) -> EndOfTurnState: - pass + audio_int16 = np.frombuffer(buffer, dtype=np.int16) + # Divide by 32768 because we have signed 16-bit data. + audio_float32 = np.frombuffer(audio_int16, dtype=np.int16).astype(np.float32) / 32768.0 + + state = EndOfTurnState.INCOMPLETE + if is_speech: + self._silence_frames = 0 + self._speech_triggered = True + if self._speech_start_time is None: + self._speech_start_time = time.time() + self._audio_buffer.append((time.time(), audio_float32)) + else: + if self._speech_triggered: + self._audio_buffer.append((time.time(), audio_float32)) + self._silence_frames += 1 + if self._silence_frames * self._chunk_size_ms >= STOP_MS: + logger.debug("End of Turn complete due to STOP_MS.") + state = EndOfTurnState.COMPLETE + self._clear() + else: + # Keep buffering some silence before potential speech starts + self._audio_buffer.append((time.time(), audio_float32)) + # Keep the buffer size reasonable, assuming CHUNK is small + max_buffer_time = ( + PRE_SPEECH_MS + STOP_MS + ) / 1000 + MAX_DURATION_SECONDS # Some extra buffer + while ( + self._audio_buffer and self._audio_buffer[0][0] < time.time() - max_buffer_time + ): + self._audio_buffer.pop(0) + + return state + + def analyze_end_of_turn(self) -> EndOfTurnState: + logger.debug("Analyzing End of Turn...") + state = self._process_speech_segment(self._audio_buffer) + if state == EndOfTurnState.COMPLETE: + self._clear() + + logger.debug(f"End of Turn result: {state}") + return state + + def _clear(self): + self._speech_triggered = False + self._audio_buffer = [] + self._speech_start_time = None + self._silence_frames = 0 + + def _process_speech_segment(self, audio_buffer) -> EndOfTurnState: + state = EndOfTurnState.INCOMPLETE + + if not audio_buffer: + return state + + # Find start and end indices for the segment + start_time = self._speech_start_time - (PRE_SPEECH_MS / 1000) + start_index = 0 + for i, (t, _) in enumerate(audio_buffer): + if t >= start_time: + start_index = i + break + + end_index = len(audio_buffer) - 1 + + # Extract the audio segment + segment_audio_chunks = [chunk for _, chunk in audio_buffer[start_index : end_index + 1]] + segment_audio = np.concatenate(segment_audio_chunks) + + # Remove (STOP_MS - 200)ms from the end of the segment + samples_to_remove = int((STOP_MS - 200) / 1000 * self.sample_rate) + segment_audio = segment_audio[:-samples_to_remove] + + # Limit maximum duration + if len(segment_audio) / self.sample_rate > MAX_DURATION_SECONDS: + segment_audio = segment_audio[: int(MAX_DURATION_SECONDS * self.sample_rate)] + + # No resampling needed as both recording and prediction use 16000 Hz + segment_audio_resampled = segment_audio + + if len(segment_audio_resampled) > 0: + # Call the new predict_endpoint function with the audio data + start_time = time.perf_counter() + + result = self._predict_endpoint(segment_audio_resampled) + + state = ( + EndOfTurnState.COMPLETE if result["prediction"] == 1 else EndOfTurnState.INCOMPLETE + ) + + end_time = time.perf_counter() + + logger.debug("--------") + logger.debug(f"Prediction: {'Complete' if result['prediction'] == 1 else 'Incomplete'}") + logger.debug(f"Probability of complete: {result['probability']:.4f}") + logger.debug(f"Prediction took {(end_time - start_time) * 1000:.2f}ms seconds") + else: + logger.debug("Captured empty audio segment, skipping prediction.") + + return state @abstractmethod - def analyze_end_of_turn(self) -> EndOfTurnState: + def _predict_endpoint(self, buffer: np.ndarray) -> Dict[str, any]: pass diff --git a/src/pipecat/audio/turn/local_smart_turn.py b/src/pipecat/audio/turn/local_smart_turn.py index ebc9250a9..efa6c81d1 100644 --- a/src/pipecat/audio/turn/local_smart_turn.py +++ b/src/pipecat/audio/turn/local_smart_turn.py @@ -6,13 +6,15 @@ import os -import time +from typing import Dict import numpy as np import torch from loguru import logger -from pipecat.audio.turn.base_turn_analyzer import BaseEndOfTurnAnalyzer, EndOfTurnState +from pipecat.audio.turn.base_turn_analyzer import ( + BaseEndOfTurnAnalyzer, +) try: import coremltools as ct @@ -25,12 +27,6 @@ except ModuleNotFoundError as e: raise Exception(f"Missing module: {e}") -# TODO: we should convert all this to params -STOP_MS = 1000 -PRE_SPEECH_MS = 200 -MAX_DURATION_SECONDS = 8 # Maximum duration for the smart turn model - - class LocalSmartTurnAnalyzer(BaseEndOfTurnAnalyzer): def __init__(self): super().__init__() @@ -63,113 +59,7 @@ class LocalSmartTurnAnalyzer(BaseEndOfTurnAnalyzer): self._turn_model = ct.models.MLModel(core_ml_model_path) logger.debug("Loaded Local Smart Turn") - self._audio_buffer = [] - self._speech_triggered = False - self._silence_frames = 0 - self._speech_start_time = None - - def append_audio(self, buffer: bytes, is_speech: bool) -> EndOfTurnState: - audio_int16 = np.frombuffer(buffer, dtype=np.int16) - # Divide by 32768 because we have signed 16-bit data. - audio_float32 = np.frombuffer(audio_int16, dtype=np.int16).astype(np.float32) / 32768.0 - - state = EndOfTurnState.INCOMPLETE - if is_speech: - self._silence_frames = 0 - self._speech_triggered = True - if self._speech_start_time is None: - self._speech_start_time = time.time() - self._audio_buffer.append((time.time(), audio_float32)) - else: - if self._speech_triggered: - self._audio_buffer.append((time.time(), audio_float32)) - self._silence_frames += 1 - if self._silence_frames * self._chunk_size_ms >= STOP_MS: - logger.debug("End of Turn complete due to STOP_MS.") - state = EndOfTurnState.COMPLETE - self._clear() - else: - # Keep buffering some silence before potential speech starts - self._audio_buffer.append((time.time(), audio_float32)) - # Keep the buffer size reasonable, assuming CHUNK is small - max_buffer_time = ( - PRE_SPEECH_MS + STOP_MS - ) / 1000 + MAX_DURATION_SECONDS # Some extra buffer - while ( - self._audio_buffer and self._audio_buffer[0][0] < time.time() - max_buffer_time - ): - self._audio_buffer.pop(0) - - return state - - def analyze_end_of_turn(self) -> EndOfTurnState: - logger.debug("Analyzing End of Turn...") - state = self._process_speech_segment(self._audio_buffer) - if state == EndOfTurnState.COMPLETE: - self._clear() - - logger.debug(f"End of Turn result: {state}") - return state - - def _clear(self): - self._speech_triggered = False - self._audio_buffer = [] - self._speech_start_time = None - self._silence_frames = 0 - - def _process_speech_segment(self, audio_buffer) -> EndOfTurnState: - state = EndOfTurnState.INCOMPLETE - - if not audio_buffer: - return state - - # Find start and end indices for the segment - start_time = self._speech_start_time - (PRE_SPEECH_MS / 1000) - start_index = 0 - for i, (t, _) in enumerate(audio_buffer): - if t >= start_time: - start_index = i - break - - end_index = len(audio_buffer) - 1 - - # Extract the audio segment - segment_audio_chunks = [chunk for _, chunk in audio_buffer[start_index : end_index + 1]] - segment_audio = np.concatenate(segment_audio_chunks) - - # Remove (STOP_MS - 200)ms from the end of the segment - samples_to_remove = int((STOP_MS - 200) / 1000 * self.sample_rate) - segment_audio = segment_audio[:-samples_to_remove] - - # Limit maximum duration - if len(segment_audio) / self.sample_rate > MAX_DURATION_SECONDS: - segment_audio = segment_audio[: int(MAX_DURATION_SECONDS * self.sample_rate)] - - # No resampling needed as both recording and prediction use 16000 Hz - segment_audio_resampled = segment_audio - - if len(segment_audio_resampled) > 0: - # Call the new predict_endpoint function with the audio data - start_time = time.perf_counter() - - result = self._predict_endpoint(segment_audio_resampled) - - state = ( - EndOfTurnState.COMPLETE if result["prediction"] == 1 else EndOfTurnState.INCOMPLETE - ) - - end_time = time.perf_counter() - - logger.debug("--------") - logger.debug(f"Prediction: {'Complete' if result['prediction'] == 1 else 'Incomplete'}") - logger.debug(f"Probability of complete: {result['probability']:.4f}") - logger.debug(f"Prediction took {(end_time - start_time) * 1000:.2f}ms seconds") - else: - logger.debug("Captured empty audio segment, skipping prediction.") - - return state - - def _predict_endpoint(self, audio_array): + def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, any]: """ Predict whether an audio segment is complete (turn ended) or incomplete.