audio: add BaseTurnAnalyzer class
This commit is contained in:
@@ -39,7 +39,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection):
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
vad_audio_passthrough=True,
|
||||
end_of_turn_analyzer=SmartTurnAnalyzer(url=remote_smart_turn_url),
|
||||
turn_analyzer=SmartTurnAnalyzer(url=remote_smart_turn_url),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection):
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
vad_audio_passthrough=True,
|
||||
end_of_turn_analyzer=LocalCoreMLSmartTurnAnalyzer(
|
||||
turn_analyzer=LocalCoreMLSmartTurnAnalyzer(
|
||||
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
|
||||
),
|
||||
),
|
||||
|
||||
@@ -5,20 +5,14 @@
|
||||
#
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Enum for end-of-turn detection states
|
||||
class EndOfTurnState(Enum):
|
||||
COMPLETE = 1
|
||||
INCOMPLETE = 2
|
||||
|
||||
from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer, EndOfTurnState
|
||||
|
||||
# Default timing parameters
|
||||
STOP_SECS = 3
|
||||
@@ -35,14 +29,13 @@ class SmartTurnParams(BaseModel):
|
||||
# use_only_last_vad_segment: bool = USE_ONLY_LAST_VAD_SEGMENT
|
||||
|
||||
|
||||
class BaseSmartTurn(ABC):
|
||||
class BaseSmartTurn(BaseTurnAnalyzer):
|
||||
def __init__(
|
||||
self, *, sample_rate: Optional[int] = None, params: SmartTurnParams = SmartTurnParams()
|
||||
):
|
||||
self._init_sample_rate = sample_rate
|
||||
super().__init__(sample_rate=sample_rate)
|
||||
self._params = params
|
||||
# Configuration
|
||||
self._sample_rate = 0
|
||||
self._stop_ms = self._params.stop_secs * 1000 # silence threshold in ms
|
||||
# Inference state
|
||||
self._audio_buffer = []
|
||||
@@ -50,13 +43,6 @@ class BaseSmartTurn(ABC):
|
||||
self._silence_ms = 0
|
||||
self._speech_start_time = None
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
return self._sample_rate
|
||||
|
||||
def set_sample_rate(self, sample_rate: int):
|
||||
self._sample_rate = sample_rate
|
||||
|
||||
@property
|
||||
def speech_triggered(self) -> bool:
|
||||
return self._speech_triggered
|
||||
|
||||
81
src/pipecat/audio/turn/base_turn_analyzer.py
Normal file
81
src/pipecat/audio/turn/base_turn_analyzer.py
Normal file
@@ -0,0 +1,81 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class EndOfTurnState(Enum):
    """Result of an end-of-turn check performed by a turn analyzer.

    Members:
        COMPLETE: The analyzer considers the user's turn finished.
        INCOMPLETE: The analyzer considers the user's turn still in progress.
    """

    COMPLETE = 1
    INCOMPLETE = 2
|
||||
|
||||
|
||||
class BaseTurnAnalyzer(ABC):
    """Abstract base class for analyzing user end of turn.

    Subclasses must implement ``speech_triggered``, ``append_audio`` and
    ``analyze_end_of_turn``. This base class only manages the sample rate.
    """

    def __init__(self, *, sample_rate: Optional[int] = None):
        """Store the optional fixed sample rate.

        Args:
            sample_rate: Sample rate to force for this analyzer. When given
                (and non-zero) it takes precedence over any value later
                passed to ``set_sample_rate``.
        """
        self._init_sample_rate = sample_rate
        # The effective rate stays 0 until set_sample_rate() is called.
        self._sample_rate = 0

    @property
    def sample_rate(self) -> int:
        """Return the current sample rate.

        Returns:
            int: The effective sample rate for audio processing (0 until
            ``set_sample_rate`` has been called).
        """
        return self._sample_rate

    def set_sample_rate(self, sample_rate: int):
        """Set the sample rate for audio processing.

        If a sample rate was provided at construction time, that value is
        used; otherwise the provided ``sample_rate`` is used.

        Args:
            sample_rate (int): The sample rate to fall back to.
        """
        # `or` also treats an initial rate of 0 as "not provided".
        self._sample_rate = self._init_sample_rate or sample_rate

    @property
    @abstractmethod
    def speech_triggered(self) -> bool:
        """Determine if speech has been detected.

        Returns:
            bool: True if speech is triggered, otherwise False.
        """
        pass

    @abstractmethod
    def append_audio(self, buffer: bytes, is_speech: bool) -> EndOfTurnState:
        """Append audio data for analysis.

        Args:
            buffer (bytes): The audio data to append.
            is_speech (bool): Whether the appended audio is speech or not.

        Returns:
            EndOfTurnState: The resulting state after appending the audio.
        """
        pass

    @abstractmethod
    def analyze_end_of_turn(self) -> EndOfTurnState:
        """Analyze if an end of turn has occurred based on the audio input.

        Returns:
            EndOfTurnState: The result of the end of turn analysis.
        """
        pass
|
||||
@@ -10,7 +10,7 @@ from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.base_smart_turn import BaseSmartTurn, EndOfTurnState
|
||||
from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer, EndOfTurnState
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
@@ -66,8 +66,8 @@ class BaseInputTransport(FrameProcessor):
|
||||
return self._params.vad_analyzer
|
||||
|
||||
@property
|
||||
def end_of_turn_analyzer(self) -> Optional[BaseSmartTurn]:
|
||||
return self._params.end_of_turn_analyzer
|
||||
def turn_analyzer(self) -> Optional[BaseTurnAnalyzer]:
|
||||
return self._params.turn_analyzer
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
self._sample_rate = self._params.audio_in_sample_rate or frame.audio_in_sample_rate
|
||||
@@ -76,8 +76,8 @@ class BaseInputTransport(FrameProcessor):
|
||||
if self._params.vad_enabled and self._params.vad_analyzer:
|
||||
self._params.vad_analyzer.set_sample_rate(self._sample_rate)
|
||||
# Configure End of turn analyzer.
|
||||
if self._params.end_of_turn_analyzer:
|
||||
self._params.end_of_turn_analyzer.set_sample_rate(self._sample_rate)
|
||||
if self._params.turn_analyzer:
|
||||
self._params.turn_analyzer.set_sample_rate(self._sample_rate)
|
||||
# Start audio filter.
|
||||
if self._params.audio_in_filter:
|
||||
await self._params.audio_in_filter.start(self._sample_rate)
|
||||
@@ -199,8 +199,8 @@ class BaseInputTransport(FrameProcessor):
|
||||
# - Creating the UserStoppedSpeakingFrame
|
||||
# - Creating the UserStartedSpeakingFrame multiple times
|
||||
can_create_user_frames = (
|
||||
self._params.end_of_turn_analyzer is None
|
||||
or not self._params.end_of_turn_analyzer.speech_triggered
|
||||
self._params.turn_analyzer is None
|
||||
or not self._params.turn_analyzer.speech_triggered
|
||||
)
|
||||
if can_create_user_frames:
|
||||
if new_vad_state == VADState.SPEAKING:
|
||||
@@ -215,9 +215,9 @@ class BaseInputTransport(FrameProcessor):
|
||||
return vad_state
|
||||
|
||||
async def _handle_end_of_turn(self):
|
||||
if self.end_of_turn_analyzer:
|
||||
if self.turn_analyzer:
|
||||
state = await self.get_event_loop().run_in_executor(
|
||||
self._executor, self.end_of_turn_analyzer.analyze_end_of_turn
|
||||
self._executor, self.turn_analyzer.analyze_end_of_turn
|
||||
)
|
||||
await self._handle_end_of_turn_complete(state)
|
||||
|
||||
@@ -230,7 +230,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
):
|
||||
is_speech = vad_state == VADState.SPEAKING or vad_state == VADState.STARTING
|
||||
# If silence exceeds threshold, we are going to receive EndOfTurnState.COMPLETE
|
||||
end_of_turn_state = self._params.end_of_turn_analyzer.append_audio(frame.audio, is_speech)
|
||||
end_of_turn_state = self._params.turn_analyzer.append_audio(frame.audio, is_speech)
|
||||
if end_of_turn_state == EndOfTurnState.COMPLETE:
|
||||
await self._handle_end_of_turn_complete(end_of_turn_state)
|
||||
# Otherwise we are going to trigger to check if the turn is completed based on the VAD
|
||||
@@ -255,7 +255,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
vad_state = await self._handle_vad(frame, vad_state)
|
||||
audio_passthrough = self._params.vad_audio_passthrough
|
||||
|
||||
if self._params.end_of_turn_analyzer:
|
||||
if self._params.turn_analyzer:
|
||||
await self._run_turn_analyzer(frame, vad_state, previous_vad_state)
|
||||
|
||||
# Push audio downstream if passthrough.
|
||||
|
||||
@@ -11,7 +11,7 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from pipecat.audio.filters.base_audio_filter import BaseAudioFilter
|
||||
from pipecat.audio.mixers.base_audio_mixer import BaseAudioMixer
|
||||
from pipecat.audio.turn.base_smart_turn import BaseSmartTurn
|
||||
from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
@@ -42,7 +42,7 @@ class TransportParams(BaseModel):
|
||||
vad_enabled: bool = False
|
||||
vad_audio_passthrough: bool = False
|
||||
vad_analyzer: Optional[VADAnalyzer] = None
|
||||
end_of_turn_analyzer: Optional[BaseSmartTurn] = None
|
||||
turn_analyzer: Optional[BaseTurnAnalyzer] = None
|
||||
|
||||
|
||||
class BaseTransport(BaseObject):
|
||||
|
||||
Reference in New Issue
Block a user