audio: add BaseTurnAnalyzer class

This commit is contained in:
Aleix Conchillo Flaqué
2025-04-17 18:13:22 -07:00
parent 004a920920
commit 9eed225aa2
6 changed files with 100 additions and 33 deletions

View File

@@ -39,7 +39,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection):
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
vad_audio_passthrough=True,
end_of_turn_analyzer=SmartTurnAnalyzer(url=remote_smart_turn_url),
turn_analyzer=SmartTurnAnalyzer(url=remote_smart_turn_url),
),
)

View File

@@ -55,7 +55,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection):
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
vad_audio_passthrough=True,
end_of_turn_analyzer=LocalCoreMLSmartTurnAnalyzer(
turn_analyzer=LocalCoreMLSmartTurnAnalyzer(
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
),
),

View File

@@ -5,20 +5,14 @@
#
import time
from abc import ABC, abstractmethod
from enum import Enum
from abc import abstractmethod
from typing import Dict, Optional
import numpy as np
from loguru import logger
from pydantic import BaseModel
# Enum for end-of-turn detection states
class EndOfTurnState(Enum):
COMPLETE = 1
INCOMPLETE = 2
from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer, EndOfTurnState
# Default timing parameters
STOP_SECS = 3
@@ -35,14 +29,13 @@ class SmartTurnParams(BaseModel):
# use_only_last_vad_segment: bool = USE_ONLY_LAST_VAD_SEGMENT
class BaseSmartTurn(ABC):
class BaseSmartTurn(BaseTurnAnalyzer):
def __init__(
self, *, sample_rate: Optional[int] = None, params: SmartTurnParams = SmartTurnParams()
):
self._init_sample_rate = sample_rate
super().__init__(sample_rate=sample_rate)
self._params = params
# Configuration
self._sample_rate = 0
self._stop_ms = self._params.stop_secs * 1000 # silence threshold in ms
# Inference state
self._audio_buffer = []
@@ -50,13 +43,6 @@ class BaseSmartTurn(ABC):
self._silence_ms = 0
self._speech_start_time = None
@property
def sample_rate(self) -> int:
return self._sample_rate
def set_sample_rate(self, sample_rate: int):
self._sample_rate = sample_rate
@property
def speech_triggered(self) -> bool:
return self._speech_triggered

View File

@@ -0,0 +1,81 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
class EndOfTurnState(Enum):
COMPLETE = 1
INCOMPLETE = 2
class BaseTurnAnalyzer(ABC):
"""
Abstract base class for analyzing user end of turn.
"""
def __init__(self, *, sample_rate: Optional[int] = None):
self._init_sample_rate = sample_rate
self._sample_rate = 0
@property
def sample_rate(self) -> int:
"""
Returns the current sample rate.
Returns:
int: The effective sample rate for audio processing.
"""
return self._sample_rate
def set_sample_rate(self, sample_rate: int):
"""
Sets the sample rate for audio processing.
If the initial sample rate was provided, it will use that; otherwise, it sets to
the provided sample rate.
Args:
sample_rate (int): The sample rate to set.
"""
self._sample_rate = self._init_sample_rate or sample_rate
@property
@abstractmethod
def speech_triggered(self) -> bool:
"""
Determines if speech has been detected.
Returns:
bool: True if speech is triggered, otherwise False.
"""
pass
@abstractmethod
def append_audio(self, buffer: bytes, is_speech: bool) -> EndOfTurnState:
"""
Appends audio data for analysis.
Args:
buffer (bytes): The audio data to append.
is_speech (bool): Indicates whether the appended audio is speech or not.
Returns:
EndOfTurnState: The resulting state after appending the audio.
"""
pass
@abstractmethod
def analyze_end_of_turn(self) -> EndOfTurnState:
"""
Analyzes if an end of turn has occurred based on the audio input.
Returns:
EndOfTurnState: The result of the end of turn analysis.
"""
pass

View File

@@ -10,7 +10,7 @@ from typing import Optional
from loguru import logger
from pipecat.audio.turn.base_smart_turn import BaseSmartTurn, EndOfTurnState
from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer, EndOfTurnState
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
from pipecat.frames.frames import (
BotInterruptionFrame,
@@ -66,8 +66,8 @@ class BaseInputTransport(FrameProcessor):
return self._params.vad_analyzer
@property
def end_of_turn_analyzer(self) -> Optional[BaseSmartTurn]:
return self._params.end_of_turn_analyzer
def turn_analyzer(self) -> Optional[BaseTurnAnalyzer]:
return self._params.turn_analyzer
async def start(self, frame: StartFrame):
self._sample_rate = self._params.audio_in_sample_rate or frame.audio_in_sample_rate
@@ -76,8 +76,8 @@ class BaseInputTransport(FrameProcessor):
if self._params.vad_enabled and self._params.vad_analyzer:
self._params.vad_analyzer.set_sample_rate(self._sample_rate)
# Configure End of turn analyzer.
if self._params.end_of_turn_analyzer:
self._params.end_of_turn_analyzer.set_sample_rate(self._sample_rate)
if self._params.turn_analyzer:
self._params.turn_analyzer.set_sample_rate(self._sample_rate)
# Start audio filter.
if self._params.audio_in_filter:
await self._params.audio_in_filter.start(self._sample_rate)
@@ -199,8 +199,8 @@ class BaseInputTransport(FrameProcessor):
# - Creating the UserStoppedSpeakingFrame
# - Creating the UserStartedSpeakingFrame multiple times
can_create_user_frames = (
self._params.end_of_turn_analyzer is None
or not self._params.end_of_turn_analyzer.speech_triggered
self._params.turn_analyzer is None
or not self._params.turn_analyzer.speech_triggered
)
if can_create_user_frames:
if new_vad_state == VADState.SPEAKING:
@@ -215,9 +215,9 @@ class BaseInputTransport(FrameProcessor):
return vad_state
async def _handle_end_of_turn(self):
if self.end_of_turn_analyzer:
if self.turn_analyzer:
state = await self.get_event_loop().run_in_executor(
self._executor, self.end_of_turn_analyzer.analyze_end_of_turn
self._executor, self.turn_analyzer.analyze_end_of_turn
)
await self._handle_end_of_turn_complete(state)
@@ -230,7 +230,7 @@ class BaseInputTransport(FrameProcessor):
):
is_speech = vad_state == VADState.SPEAKING or vad_state == VADState.STARTING
# If silence exceeds threshold, we are going to receive EndOfTurnState.COMPLETE
end_of_turn_state = self._params.end_of_turn_analyzer.append_audio(frame.audio, is_speech)
end_of_turn_state = self._params.turn_analyzer.append_audio(frame.audio, is_speech)
if end_of_turn_state == EndOfTurnState.COMPLETE:
await self._handle_end_of_turn_complete(end_of_turn_state)
# Otherwise we are going to trigger to check if the turn is completed based on the VAD
@@ -255,7 +255,7 @@ class BaseInputTransport(FrameProcessor):
vad_state = await self._handle_vad(frame, vad_state)
audio_passthrough = self._params.vad_audio_passthrough
if self._params.end_of_turn_analyzer:
if self._params.turn_analyzer:
await self._run_turn_analyzer(frame, vad_state, previous_vad_state)
# Push audio downstream if passthrough.

View File

@@ -11,7 +11,7 @@ from pydantic import BaseModel, ConfigDict
from pipecat.audio.filters.base_audio_filter import BaseAudioFilter
from pipecat.audio.mixers.base_audio_mixer import BaseAudioMixer
from pipecat.audio.turn.base_smart_turn import BaseSmartTurn
from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer
from pipecat.audio.vad.vad_analyzer import VADAnalyzer
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.utils.base_object import BaseObject
@@ -42,7 +42,7 @@ class TransportParams(BaseModel):
vad_enabled: bool = False
vad_audio_passthrough: bool = False
vad_analyzer: Optional[VADAnalyzer] = None
end_of_turn_analyzer: Optional[BaseSmartTurn] = None
turn_analyzer: Optional[BaseTurnAnalyzer] = None
class BaseTransport(BaseObject):