Compare commits

...

1 Commits

Author SHA1 Message Date
Aleix Conchillo Flaqué
e9d2cd6d30 initial smart-turn (end of turn detection) support 2025-03-10 13:43:30 -07:00
6 changed files with 158 additions and 0 deletions

View File

View File

@@ -0,0 +1,32 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
class EndOfTurnState(Enum):
COMPLETE = 1
INCOMPLETE = 2
class BaseEndOfTurnAnalyzer(ABC):
def __init__(self, *, sample_rate: Optional[int] = None):
self._init_sample_rate = sample_rate
self._sample_rate = 0
@property
def sample_rate(self) -> int:
return self._sample_rate
def set_sample_rate(self, sample_rate: int):
self._sample_rate = self._init_sample_rate or sample_rate
@abstractmethod
def analyze_audio(self, buffer: bytes) -> EndOfTurnState:
pass

View File

@@ -0,0 +1,83 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import numpy as np
import torch
from loguru import logger
from transformers import AutoFeatureExtractor, Wav2Vec2BertForSequenceClassification
from pipecat.audio.turn.base_turn_analyzer import BaseEndOfTurnAnalyzer, EndOfTurnState
# MODEL_PATH = "model-v1"
MODEL_PATH = "pipecat-ai/smart-turn"
class SmartTurnAnalyzer(BaseEndOfTurnAnalyzer):
def __init__(self):
super().__init__()
self._audio_buffer = bytearray()
logger.debug("Loading Smart Turn model...")
# Load model and processor
model = Wav2Vec2BertForSequenceClassification.from_pretrained(MODEL_PATH)
self._processor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)
# Set model to evaluation mode and move to GPU if available
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self._model = model.to(self._device)
self._model.eval()
logger.debug("Loaded Smart Turn")
def analyze_audio(self, buffer: bytes) -> EndOfTurnState:
self._audio_buffer += buffer
if len(self._audio_buffer) < 16000 * 2 * 6:
return EndOfTurnState.INCOMPLETE
audio_int16 = np.frombuffer(self._audio_buffer, dtype=np.int16)
# Divide by 32768 because we have signed 16-bit data.
audio_float32 = np.frombuffer(audio_int16, dtype=np.int16).astype(np.float32) / 32768.0
print(audio_float32)
# Process audio
inputs = self._processor(
audio_float32,
sampling_rate=16000,
padding="max_length",
truncation=True,
max_length=800, # Maximum length as specified in training
return_attention_mask=True,
return_tensors="pt",
)
# Move inputs to device
inputs = {k: v.to(self._device) for k, v in inputs.items()}
# Run inference
with torch.no_grad():
outputs = self._model(**inputs)
logits = outputs.logits
# Get probabilities using softmax
probabilities = torch.nn.functional.softmax(logits, dim=1)
completion_prob = probabilities[0, 1].item() # Probability of class 1 (Complete)
# Make prediction (1 for Complete, 0 for Incomplete)
prediction = 1 if completion_prob > 0.5 else 0
state = EndOfTurnState.COMPLETE if prediction == 1 else EndOfTurnState.INCOMPLETE
if state == EndOfTurnState.COMPLETE:
self._audio_buffer = bytearray()
else:
self._audio_buffer = self._audio_buffer[len(buffer) :]
print("AAAAAAAAAAAA", state)
return state

View File

@@ -583,6 +583,18 @@ class EmulateUserStoppedSpeakingFrame(SystemFrame):
pass
@dataclass
class UserEndOfTurnFrame(SystemFrame):
"""Emitted by VAD to indicate that a user has started speaking. This can be
used for interruptions or other times when detecting that someone is
speaking is more important than knowing what they're saying (as you will
with a TranscriptionFrame)
"""
pass
@dataclass
class BotInterruptionFrame(SystemFrame):
"""Emitted by when the bot should be interrupted. This will mainly cause the

View File

@@ -10,6 +10,7 @@ from typing import Optional
from loguru import logger
from pipecat.audio.turn.base_turn_analyzer import BaseEndOfTurnAnalyzer, EndOfTurnState
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
from pipecat.frames.frames import (
BotInterruptionFrame,
@@ -24,6 +25,7 @@ from pipecat.frames.frames import (
StartInterruptionFrame,
StopInterruptionFrame,
SystemFrame,
UserEndOfTurnFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
VADParamsUpdateFrame,
@@ -64,12 +66,19 @@ class BaseInputTransport(FrameProcessor):
def vad_analyzer(self) -> Optional[VADAnalyzer]:
return self._params.vad_analyzer
@property
def end_of_turn_analyzer(self) -> Optional[BaseEndOfTurnAnalyzer]:
return self._params.end_of_turn_analyzer
async def start(self, frame: StartFrame):
self._sample_rate = self._params.audio_in_sample_rate or frame.audio_in_sample_rate
# Configure VAD analyzer.
if self._params.vad_enabled and self._params.vad_analyzer:
self._params.vad_analyzer.set_sample_rate(self._sample_rate)
# Configure End of turn analyzer.
if self._params.end_of_turn_analyzer:
self._params.end_of_turn_analyzer.set_sample_rate(self._sample_rate)
# Start audio filter.
if self._params.audio_in_filter:
await self._params.audio_in_filter.start(self._sample_rate)
@@ -198,8 +207,25 @@ class BaseInputTransport(FrameProcessor):
vad_state = new_vad_state
return vad_state
async def _end_of_turn_analyze(self, audio_frame: InputAudioRawFrame) -> EndOfTurnState:
state = EndOfTurnState.INCOMPLETE
if self.end_of_turn_analyzer:
state = await self.get_event_loop().run_in_executor(
self._executor, self.end_of_turn_analyzer.analyze_audio, audio_frame.audio
)
return state
async def _handle_end_of_turn(
self, audio_frame: InputAudioRawFrame, end_of_turn_state: EndOfTurnState
):
new_eot_state = await self._end_of_turn_analyze(audio_frame)
if new_eot_state != end_of_turn_state:
await self.push_frame(UserEndOfTurnFrame())
return new_eot_state
async def _audio_task_handler(self):
vad_state: VADState = VADState.QUIET
end_of_turn_state: EndOfTurnState = EndOfTurnState.INCOMPLETE
while True:
frame: InputAudioRawFrame = await self._audio_in_queue.get()
@@ -215,6 +241,9 @@ class BaseInputTransport(FrameProcessor):
vad_state = await self._handle_vad(frame, vad_state)
audio_passthrough = self._params.vad_audio_passthrough
if self._params.end_of_turn_analyzer:
end_of_turn_state = await self._handle_end_of_turn(frame, end_of_turn_state)
# Push audio downstream if passthrough.
if audio_passthrough:
await self.push_frame(frame)

View File

@@ -11,6 +11,7 @@ from pydantic import BaseModel, ConfigDict
from pipecat.audio.filters.base_audio_filter import BaseAudioFilter
from pipecat.audio.mixers.base_audio_mixer import BaseAudioMixer
from pipecat.audio.turn.base_turn_analyzer import BaseEndOfTurnAnalyzer
from pipecat.audio.vad.vad_analyzer import VADAnalyzer
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.utils.base_object import BaseObject
@@ -39,6 +40,7 @@ class TransportParams(BaseModel):
vad_enabled: bool = False
vad_audio_passthrough: bool = False
vad_analyzer: Optional[VADAnalyzer] = None
end_of_turn_analyzer: Optional[BaseEndOfTurnAnalyzer] = None
class BaseTransport(BaseObject):