Fix linter warnings
This commit is contained in:
@@ -19,9 +19,14 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
|
||||
|
||||
try:
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import Wav2Vec2PreTrainedModel, Wav2Vec2Model, Wav2Vec2Processor, Wav2Vec2Config
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
Wav2Vec2Config,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2PreTrainedModel,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
@@ -54,9 +59,7 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn):
|
||||
|
||||
logger.debug("Loading Local Smart Turn v2 model...")
|
||||
# Load the pretrained model for sequence classification
|
||||
self._turn_model = Wav2Vec2ForEndpointing.from_pretrained(
|
||||
smart_turn_model_path
|
||||
)
|
||||
self._turn_model = _Wav2Vec2ForEndpointing.from_pretrained(smart_turn_model_path)
|
||||
# Load the corresponding feature extractor for preprocessing audio
|
||||
self._turn_processor = Wav2Vec2Processor.from_pretrained(smart_turn_model_path)
|
||||
# Set device to GPU if available, else CPU
|
||||
@@ -75,7 +78,7 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn):
|
||||
truncation=True,
|
||||
max_length=16000 * 16, # 16 seconds at 16kHz
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt"
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Move inputs to device
|
||||
@@ -96,15 +99,14 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn):
|
||||
"probability": probability,
|
||||
}
|
||||
|
||||
class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
|
||||
|
||||
class _Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
|
||||
def __init__(self, config: Wav2Vec2Config):
|
||||
super().__init__(config)
|
||||
self.wav2vec2 = Wav2Vec2Model(config)
|
||||
|
||||
self.pool_attention = nn.Sequential(
|
||||
nn.Linear(config.hidden_size, 256),
|
||||
nn.Tanh(),
|
||||
nn.Linear(256, 1)
|
||||
nn.Linear(config.hidden_size, 256), nn.Tanh(), nn.Linear(256, 1)
|
||||
)
|
||||
|
||||
self.classifier = nn.Sequential(
|
||||
@@ -114,7 +116,7 @@ class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(256, 64),
|
||||
nn.GELU(),
|
||||
nn.Linear(64, 1)
|
||||
nn.Linear(64, 1),
|
||||
)
|
||||
|
||||
for module in self.classifier:
|
||||
@@ -137,7 +139,7 @@ class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
|
||||
raise ValueError("attention_mask must be provided for attention pooling")
|
||||
|
||||
attention_weights = attention_weights + (
|
||||
(1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9
|
||||
(1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9
|
||||
)
|
||||
|
||||
attention_weights = F.softmax(attention_weights, dim=1)
|
||||
@@ -178,7 +180,7 @@ class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
|
||||
|
||||
# Add L2 regularization for classifier layers
|
||||
l2_lambda = 0.01
|
||||
l2_reg = torch.tensor(0., device=logits.device)
|
||||
l2_reg = torch.tensor(0.0, device=logits.device)
|
||||
for param in self.classifier.parameters():
|
||||
l2_reg += torch.norm(param)
|
||||
loss += l2_lambda * l2_reg
|
||||
@@ -187,4 +189,4 @@ class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
|
||||
return {"loss": loss, "logits": probs}
|
||||
|
||||
probs = torch.sigmoid(logits)
|
||||
return {"logits": probs}
|
||||
return {"logits": probs}
|
||||
|
||||
Reference in New Issue
Block a user