From 9eed225aa2071b22d33d41cc7b5bde3fdb80dfef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 17 Apr 2025 18:13:22 -0700 Subject: [PATCH] audio: add BaseTurnAnalyzer class --- examples/foundational/38-smart-turn.py | 2 +- examples/foundational/38a-local-smart-turn.py | 2 +- src/pipecat/audio/turn/base_smart_turn.py | 22 +---- src/pipecat/audio/turn/base_turn_analyzer.py | 81 +++++++++++++++++++ src/pipecat/transports/base_input.py | 22 ++--- src/pipecat/transports/base_transport.py | 4 +- 6 files changed, 100 insertions(+), 33 deletions(-) create mode 100644 src/pipecat/audio/turn/base_turn_analyzer.py diff --git a/examples/foundational/38-smart-turn.py b/examples/foundational/38-smart-turn.py index 03d530b90..6bace018b 100644 --- a/examples/foundational/38-smart-turn.py +++ b/examples/foundational/38-smart-turn.py @@ -39,7 +39,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection): vad_enabled=True, vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), vad_audio_passthrough=True, - end_of_turn_analyzer=SmartTurnAnalyzer(url=remote_smart_turn_url), + turn_analyzer=SmartTurnAnalyzer(url=remote_smart_turn_url), ), ) diff --git a/examples/foundational/38a-local-smart-turn.py b/examples/foundational/38a-local-smart-turn.py index 7baedf10e..c1260c248 100644 --- a/examples/foundational/38a-local-smart-turn.py +++ b/examples/foundational/38a-local-smart-turn.py @@ -55,7 +55,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection): vad_enabled=True, vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), vad_audio_passthrough=True, - end_of_turn_analyzer=LocalCoreMLSmartTurnAnalyzer( + turn_analyzer=LocalCoreMLSmartTurnAnalyzer( smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams() ), ), diff --git a/src/pipecat/audio/turn/base_smart_turn.py b/src/pipecat/audio/turn/base_smart_turn.py index 0716d4e7c..eab02dab8 100644 --- a/src/pipecat/audio/turn/base_smart_turn.py +++ b/src/pipecat/audio/turn/base_smart_turn.py @@ -5,20 +5,14 @@ # import time -from abc import ABC, abstractmethod -from enum import Enum +from abc import abstractmethod from typing import Dict, Optional import numpy as np from loguru import logger from pydantic import BaseModel - -# Enum for end-of-turn detection states -class EndOfTurnState(Enum): - COMPLETE = 1 - INCOMPLETE = 2 - +from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer, EndOfTurnState # Default timing parameters STOP_SECS = 3 @@ -35,14 +29,13 @@ class SmartTurnParams(BaseModel): # use_only_last_vad_segment: bool = USE_ONLY_LAST_VAD_SEGMENT -class BaseSmartTurn(ABC): +class BaseSmartTurn(BaseTurnAnalyzer): def __init__( self, *, sample_rate: Optional[int] = None, params: SmartTurnParams = SmartTurnParams() ): - self._init_sample_rate = sample_rate + super().__init__(sample_rate=sample_rate) self._params = params # Configuration - self._sample_rate = 0 self._stop_ms = self._params.stop_secs * 1000 # silence threshold in ms # Inference state self._audio_buffer = [] @@ -50,13 +43,6 @@ class BaseSmartTurn(ABC): self._silence_ms = 0 self._speech_start_time = None - @property - def sample_rate(self) -> int: - return self._sample_rate - - def set_sample_rate(self, sample_rate: int): - self._sample_rate = sample_rate - @property def speech_triggered(self) -> bool: return self._speech_triggered diff --git a/src/pipecat/audio/turn/base_turn_analyzer.py b/src/pipecat/audio/turn/base_turn_analyzer.py new file mode 100644 index 000000000..4d9fc12a6 --- /dev/null +++ b/src/pipecat/audio/turn/base_turn_analyzer.py @@ -0,0 +1,81 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Optional + + +class EndOfTurnState(Enum): + COMPLETE = 1 + INCOMPLETE = 2 + + +class BaseTurnAnalyzer(ABC): + """ + Abstract base class for analyzing user end of turn. + """ + + def __init__(self, *, sample_rate: Optional[int] = None): + self._init_sample_rate = sample_rate + self._sample_rate = 0 + + @property + def sample_rate(self) -> int: + """ + Returns the current sample rate. + + Returns: + int: The effective sample rate for audio processing. + """ + return self._sample_rate + + def set_sample_rate(self, sample_rate: int): + """ + Sets the sample rate for audio processing. + + If the initial sample rate was provided, it will use that; otherwise, it sets to + the provided sample rate. + + Args: + sample_rate (int): The sample rate to set. + """ + self._sample_rate = self._init_sample_rate or sample_rate + + @property + @abstractmethod + def speech_triggered(self) -> bool: + """ + Determines if speech has been detected. + + Returns: + bool: True if speech is triggered, otherwise False. + """ + pass + + @abstractmethod + def append_audio(self, buffer: bytes, is_speech: bool) -> EndOfTurnState: + """ + Appends audio data for analysis. + + Args: + buffer (bytes): The audio data to append. + is_speech (bool): Indicates whether the appended audio is speech or not. + + Returns: + EndOfTurnState: The resulting state after appending the audio. + """ + pass + + @abstractmethod + def analyze_end_of_turn(self) -> EndOfTurnState: + """ + Analyzes if an end of turn has occurred based on the audio input. + + Returns: + EndOfTurnState: The result of the end of turn analysis. + """ + pass diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 3c7cf868d..55f4e9fa0 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -10,7 +10,7 @@ from typing import Optional from loguru import logger -from pipecat.audio.turn.base_smart_turn import BaseSmartTurn, EndOfTurnState +from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer, EndOfTurnState from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState from pipecat.frames.frames import ( BotInterruptionFrame, @@ -66,8 +66,8 @@ class BaseInputTransport(FrameProcessor): return self._params.vad_analyzer @property - def end_of_turn_analyzer(self) -> Optional[BaseSmartTurn]: - return self._params.end_of_turn_analyzer + def turn_analyzer(self) -> Optional[BaseTurnAnalyzer]: + return self._params.turn_analyzer async def start(self, frame: StartFrame): self._sample_rate = self._params.audio_in_sample_rate or frame.audio_in_sample_rate @@ -76,8 +76,8 @@ class BaseInputTransport(FrameProcessor): if self._params.vad_enabled and self._params.vad_analyzer: self._params.vad_analyzer.set_sample_rate(self._sample_rate) # Configure End of turn analyzer. - if self._params.end_of_turn_analyzer: - self._params.end_of_turn_analyzer.set_sample_rate(self._sample_rate) + if self._params.turn_analyzer: + self._params.turn_analyzer.set_sample_rate(self._sample_rate) # Start audio filter. if self._params.audio_in_filter: await self._params.audio_in_filter.start(self._sample_rate) @@ -199,8 +199,8 @@ class BaseInputTransport(FrameProcessor): # - Creating the UserStoppedSpeakingFrame # - Creating the UserStartedSpeakingFrame multiple times can_create_user_frames = ( - self._params.end_of_turn_analyzer is None - or not self._params.end_of_turn_analyzer.speech_triggered + self._params.turn_analyzer is None + or not self._params.turn_analyzer.speech_triggered ) if can_create_user_frames: if new_vad_state == VADState.SPEAKING: @@ -215,9 +215,9 @@ class BaseInputTransport(FrameProcessor): return vad_state async def _handle_end_of_turn(self): - if self.end_of_turn_analyzer: + if self.turn_analyzer: state = await self.get_event_loop().run_in_executor( - self._executor, self.end_of_turn_analyzer.analyze_end_of_turn + self._executor, self.turn_analyzer.analyze_end_of_turn ) await self._handle_end_of_turn_complete(state) @@ -230,7 +230,7 @@ class BaseInputTransport(FrameProcessor): ): is_speech = vad_state == VADState.SPEAKING or vad_state == VADState.STARTING # If silence exceeds threshold, we are going to receive EndOfTurnState.COMPLETE - end_of_turn_state = self._params.end_of_turn_analyzer.append_audio(frame.audio, is_speech) + end_of_turn_state = self._params.turn_analyzer.append_audio(frame.audio, is_speech) if end_of_turn_state == EndOfTurnState.COMPLETE: await self._handle_end_of_turn_complete(end_of_turn_state) # Otherwise we are going to trigger to check if the turn is completed based on the VAD @@ -255,7 +255,7 @@ class BaseInputTransport(FrameProcessor): vad_state = await self._handle_vad(frame, vad_state) audio_passthrough = self._params.vad_audio_passthrough - if self._params.end_of_turn_analyzer: + if self._params.turn_analyzer: await self._run_turn_analyzer(frame, vad_state, previous_vad_state) # Push audio downstream if passthrough. diff --git a/src/pipecat/transports/base_transport.py b/src/pipecat/transports/base_transport.py index 79c876fc4..9ef573f7c 100644 --- a/src/pipecat/transports/base_transport.py +++ b/src/pipecat/transports/base_transport.py @@ -11,7 +11,7 @@ from pydantic import BaseModel, ConfigDict from pipecat.audio.filters.base_audio_filter import BaseAudioFilter from pipecat.audio.mixers.base_audio_mixer import BaseAudioMixer -from pipecat.audio.turn.base_smart_turn import BaseSmartTurn +from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer from pipecat.audio.vad.vad_analyzer import VADAnalyzer from pipecat.processors.frame_processor import FrameProcessor from pipecat.utils.base_object import BaseObject @@ -42,7 +42,7 @@ class TransportParams(BaseModel): vad_enabled: bool = False vad_audio_passthrough: bool = False vad_analyzer: Optional[VADAnalyzer] = None - end_of_turn_analyzer: Optional[BaseSmartTurn] = None + turn_analyzer: Optional[BaseTurnAnalyzer] = None class BaseTransport(BaseObject):