Compare commits
1 Commits
filipi/sma
...
aleix/smar
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9d2cd6d30 |
0
src/pipecat/audio/turn/__init__.py
Normal file
0
src/pipecat/audio/turn/__init__.py
Normal file
32
src/pipecat/audio/turn/base_turn_analyzer.py
Normal file
32
src/pipecat/audio/turn/base_turn_analyzer.py
Normal file
@@ -0,0 +1,32 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class EndOfTurnState(Enum):
    """Outcome of an end-of-turn analysis over the audio received so far."""

    # The analyzer considers the turn complete.
    COMPLETE = 1
    # The analyzer does not (yet) consider the turn complete.
    INCOMPLETE = 2
|
||||
|
||||
|
||||
class BaseEndOfTurnAnalyzer(ABC):
    """Abstract base class for end-of-turn analyzers.

    This base class only keeps track of the sample rate; subclasses provide
    the actual detection logic by implementing `analyze_audio()`.
    """

    def __init__(self, *, sample_rate: Optional[int] = None):
        """Remember an optional, explicitly requested sample rate.

        Args:
            sample_rate: When given (and non-zero) it takes precedence over
                any rate later passed to `set_sample_rate()`.
        """
        self._init_sample_rate = sample_rate
        # The effective rate stays 0 until set_sample_rate() is called.
        self._sample_rate = 0

    @property
    def sample_rate(self) -> int:
        """Return the effective sample rate (0 until `set_sample_rate()` runs)."""
        return self._sample_rate

    def set_sample_rate(self, sample_rate: int):
        """Set the effective sample rate.

        Args:
            sample_rate: Rate to use unless one was supplied to the
                constructor, in which case the constructor value wins.
        """
        if self._init_sample_rate:
            self._sample_rate = self._init_sample_rate
        else:
            self._sample_rate = sample_rate

    @abstractmethod
    def analyze_audio(self, buffer: bytes) -> EndOfTurnState:
        """Analyze a chunk of audio and report the end-of-turn state.

        Args:
            buffer: Raw audio bytes to analyze.

        Returns:
            The end-of-turn state determined by the concrete analyzer.
        """
|
||||
83
src/pipecat/audio/turn/smart_turn.py
Normal file
83
src/pipecat/audio/turn/smart_turn.py
Normal file
@@ -0,0 +1,83 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from transformers import AutoFeatureExtractor, Wav2Vec2BertForSequenceClassification
|
||||
|
||||
from pipecat.audio.turn.base_turn_analyzer import BaseEndOfTurnAnalyzer, EndOfTurnState
|
||||
|
||||
# MODEL_PATH is passed to `from_pretrained()`; a local model directory (e.g. "model-v1") can be used instead of the Hub id.
|
||||
MODEL_PATH = "pipecat-ai/smart-turn"
|
||||
|
||||
|
||||
class SmartTurnAnalyzer(BaseEndOfTurnAnalyzer):
    """End-of-turn analyzer backed by the Smart Turn audio classifier.

    Incoming audio is accumulated in an internal buffer. Once enough audio has
    been collected, every call runs the classifier loaded from `MODEL_PATH`
    over the buffered audio and maps its output to an `EndOfTurnState`.

    NOTE(review): the audio is always handed to the feature extractor as
    16 kHz, 16-bit PCM, regardless of the rate passed to `set_sample_rate()`.
    Callers presumably must deliver 16 kHz mono audio -- confirm.
    """

    # Sample rate declared to the feature extractor.
    _MODEL_SAMPLE_RATE = 16000
    # Signed 16-bit samples.
    _BYTES_PER_SAMPLE = 2
    # Amount of buffered audio (at the rate above) required before inference.
    _MIN_WINDOW_SECS = 6
    # Maximum feature length, as specified in training.
    _MAX_FEATURE_LENGTH = 800
    # Probability of the "complete" class above which the turn is complete.
    _COMPLETION_THRESHOLD = 0.5

    def __init__(self):
        """Load the model and feature extractor and select the device."""
        super().__init__()
        self._audio_buffer = bytearray()

        logger.debug("Loading Smart Turn model...")

        # Load model and processor.
        model = Wav2Vec2BertForSequenceClassification.from_pretrained(MODEL_PATH)
        self._processor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)

        # Move the model to the GPU if available and set it to evaluation mode.
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model = model.to(self._device)
        self._model.eval()

        logger.debug("Loaded Smart Turn")

    def analyze_audio(self, buffer: bytes) -> EndOfTurnState:
        """Append audio to the internal buffer and classify the buffered audio.

        Args:
            buffer: Raw signed 16-bit PCM audio bytes.

        Returns:
            `EndOfTurnState.INCOMPLETE` while less than the minimum window of
            audio is buffered or when the model's completion probability is
            not above the threshold; `EndOfTurnState.COMPLETE` otherwise.
        """
        self._audio_buffer += buffer

        min_bytes = self._MODEL_SAMPLE_RATE * self._BYTES_PER_SAMPLE * self._MIN_WINDOW_SECS
        if len(self._audio_buffer) < min_bytes:
            return EndOfTurnState.INCOMPLETE

        # Only read whole samples so an odd number of buffered bytes does not
        # make np.frombuffer() raise.
        num_samples = len(self._audio_buffer) // self._BYTES_PER_SAMPLE
        audio_int16 = np.frombuffer(self._audio_buffer, dtype=np.int16, count=num_samples)

        # Divide by 32768 because we have signed 16-bit data.
        audio_float32 = audio_int16.astype(np.float32) / 32768.0

        # Process audio.
        inputs = self._processor(
            audio_float32,
            sampling_rate=self._MODEL_SAMPLE_RATE,
            padding="max_length",
            truncation=True,
            max_length=self._MAX_FEATURE_LENGTH,
            return_attention_mask=True,
            return_tensors="pt",
        )

        # Move inputs to device.
        inputs = {k: v.to(self._device) for k, v in inputs.items()}

        # Run inference.
        with torch.no_grad():
            outputs = self._model(**inputs)
            logits = outputs.logits

            # Get probabilities using softmax; class 1 is "complete".
            probabilities = torch.nn.functional.softmax(logits, dim=1)
            completion_prob = probabilities[0, 1].item()

        if completion_prob > self._COMPLETION_THRESHOLD:
            state = EndOfTurnState.COMPLETE
            # Start from scratch for the next turn.
            self._audio_buffer = bytearray()
        else:
            state = EndOfTurnState.INCOMPLETE
            # Slide the window: drop as many bytes as were just appended.
            self._audio_buffer = self._audio_buffer[len(buffer) :]

        logger.trace(f"Smart Turn state: {state} (completion probability {completion_prob:.3f})")

        return state
|
||||
@@ -583,6 +583,18 @@ class EmulateUserStoppedSpeakingFrame(SystemFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
class UserEndOfTurnFrame(SystemFrame):
    """Emitted to signal the end of the user's turn.

    The input transport pushes this frame based on the result reported by its
    configured end-of-turn analyzer. The frame carries no additional data.

    """

    pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotInterruptionFrame(SystemFrame):
|
||||
"""Emitted by when the bot should be interrupted. This will mainly cause the
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.base_turn_analyzer import BaseEndOfTurnAnalyzer, EndOfTurnState
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
@@ -24,6 +25,7 @@ from pipecat.frames.frames import (
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
SystemFrame,
|
||||
UserEndOfTurnFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADParamsUpdateFrame,
|
||||
@@ -64,12 +66,19 @@ class BaseInputTransport(FrameProcessor):
|
||||
def vad_analyzer(self) -> Optional[VADAnalyzer]:
|
||||
return self._params.vad_analyzer
|
||||
|
||||
@property
def end_of_turn_analyzer(self) -> Optional[BaseEndOfTurnAnalyzer]:
    """Return the end-of-turn analyzer set in the transport params, if any."""
    params = self._params
    return params.end_of_turn_analyzer
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
self._sample_rate = self._params.audio_in_sample_rate or frame.audio_in_sample_rate
|
||||
|
||||
# Configure VAD analyzer.
|
||||
if self._params.vad_enabled and self._params.vad_analyzer:
|
||||
self._params.vad_analyzer.set_sample_rate(self._sample_rate)
|
||||
# Configure End of turn analyzer.
|
||||
if self._params.end_of_turn_analyzer:
|
||||
self._params.end_of_turn_analyzer.set_sample_rate(self._sample_rate)
|
||||
# Start audio filter.
|
||||
if self._params.audio_in_filter:
|
||||
await self._params.audio_in_filter.start(self._sample_rate)
|
||||
@@ -198,8 +207,25 @@ class BaseInputTransport(FrameProcessor):
|
||||
vad_state = new_vad_state
|
||||
return vad_state
|
||||
|
||||
async def _end_of_turn_analyze(self, audio_frame: InputAudioRawFrame) -> EndOfTurnState:
    """Run the end-of-turn analyzer on one audio frame.

    Args:
        audio_frame: Frame whose `audio` payload is handed to the analyzer.

    Returns:
        The analyzer's result, or `EndOfTurnState.INCOMPLETE` when no
        analyzer is configured.
    """
    analyzer = self.end_of_turn_analyzer
    if not analyzer:
        return EndOfTurnState.INCOMPLETE

    # analyze_audio() is a synchronous call, so it is dispatched to the
    # executor instead of being run directly in the event loop.
    loop = self.get_event_loop()
    return await loop.run_in_executor(self._executor, analyzer.analyze_audio, audio_frame.audio)
|
||||
|
||||
async def _handle_end_of_turn(
    self, audio_frame: InputAudioRawFrame, end_of_turn_state: EndOfTurnState
):
    """Analyze a frame and emit `UserEndOfTurnFrame` when the turn completes.

    Args:
        audio_frame: The audio frame to analyze.
        end_of_turn_state: The state returned by the previous call.

    Returns:
        The new end-of-turn state, to be passed back in on the next call.
    """
    new_eot_state = await self._end_of_turn_analyze(audio_frame)
    # Only the transition *into* COMPLETE marks an end of turn. Falling back
    # to INCOMPLETE afterwards is also a state change but must not emit
    # another end-of-turn frame.
    turn_just_completed = (
        new_eot_state == EndOfTurnState.COMPLETE and new_eot_state != end_of_turn_state
    )
    if turn_just_completed:
        await self.push_frame(UserEndOfTurnFrame())
    return new_eot_state
|
||||
|
||||
async def _audio_task_handler(self):
|
||||
vad_state: VADState = VADState.QUIET
|
||||
end_of_turn_state: EndOfTurnState = EndOfTurnState.INCOMPLETE
|
||||
while True:
|
||||
frame: InputAudioRawFrame = await self._audio_in_queue.get()
|
||||
|
||||
@@ -215,6 +241,9 @@ class BaseInputTransport(FrameProcessor):
|
||||
vad_state = await self._handle_vad(frame, vad_state)
|
||||
audio_passthrough = self._params.vad_audio_passthrough
|
||||
|
||||
if self._params.end_of_turn_analyzer:
|
||||
end_of_turn_state = await self._handle_end_of_turn(frame, end_of_turn_state)
|
||||
|
||||
# Push audio downstream if passthrough.
|
||||
if audio_passthrough:
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -11,6 +11,7 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from pipecat.audio.filters.base_audio_filter import BaseAudioFilter
|
||||
from pipecat.audio.mixers.base_audio_mixer import BaseAudioMixer
|
||||
from pipecat.audio.turn.base_turn_analyzer import BaseEndOfTurnAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
@@ -39,6 +40,7 @@ class TransportParams(BaseModel):
|
||||
vad_enabled: bool = False
|
||||
vad_audio_passthrough: bool = False
|
||||
vad_analyzer: Optional[VADAnalyzer] = None
|
||||
end_of_turn_analyzer: Optional[BaseEndOfTurnAnalyzer] = None
|
||||
|
||||
|
||||
class BaseTransport(BaseObject):
|
||||
|
||||
Reference in New Issue
Block a user