diff --git a/src/pipecat/audio/turn/smart_turn/local_smart_turn_v2.py b/src/pipecat/audio/turn/smart_turn/local_smart_turn_v2.py index 07f28c901..a2854a352 100644 --- a/src/pipecat/audio/turn/smart_turn/local_smart_turn_v2.py +++ b/src/pipecat/audio/turn/smart_turn/local_smart_turn_v2.py @@ -19,9 +19,14 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn try: import torch - from torch import nn import torch.nn.functional as F - from transformers import Wav2Vec2PreTrainedModel, Wav2Vec2Model, Wav2Vec2Processor, Wav2Vec2Config + from torch import nn + from transformers import ( + Wav2Vec2Config, + Wav2Vec2Model, + Wav2Vec2PreTrainedModel, + Wav2Vec2Processor, + ) except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error( @@ -54,9 +59,7 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn): logger.debug("Loading Local Smart Turn v2 model...") # Load the pretrained model for sequence classification - self._turn_model = Wav2Vec2ForEndpointing.from_pretrained( - smart_turn_model_path - ) + self._turn_model = _Wav2Vec2ForEndpointing.from_pretrained(smart_turn_model_path) # Load the corresponding feature extractor for preprocessing audio self._turn_processor = Wav2Vec2Processor.from_pretrained(smart_turn_model_path) # Set device to GPU if available, else CPU @@ -75,7 +78,7 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn): truncation=True, max_length=16000 * 16, # 16 seconds at 16kHz return_attention_mask=True, - return_tensors="pt" + return_tensors="pt", ) # Move inputs to device @@ -96,15 +99,14 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn): "probability": probability, } -class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel): + +class _Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel): def __init__(self, config: Wav2Vec2Config): super().__init__(config) self.wav2vec2 = Wav2Vec2Model(config) self.pool_attention = nn.Sequential( - nn.Linear(config.hidden_size, 256), - nn.Tanh(), - nn.Linear(256, 1) + nn.Linear(config.hidden_size, 256), nn.Tanh(), nn.Linear(256, 1) ) self.classifier = nn.Sequential( @@ -114,7 +116,7 @@ class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel): nn.Dropout(0.1), nn.Linear(256, 64), nn.GELU(), - nn.Linear(64, 1) + nn.Linear(64, 1), ) for module in self.classifier: @@ -137,7 +139,7 @@ class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel): raise ValueError("attention_mask must be provided for attention pooling") attention_weights = attention_weights + ( - (1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9 + (1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9 ) attention_weights = F.softmax(attention_weights, dim=1) @@ -178,7 +180,7 @@ class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel): # Add L2 regularization for classifier layers l2_lambda = 0.01 - l2_reg = torch.tensor(0., device=logits.device) + l2_reg = torch.tensor(0.0, device=logits.device) for param in self.classifier.parameters(): l2_reg += torch.norm(param) loss += l2_lambda * l2_reg @@ -187,4 +189,4 @@ class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel): return {"loss": loss, "logits": probs} probs = torch.sigmoid(logits) - return {"logits": probs} \ No newline at end of file + return {"logits": probs}