Fix linter warnings

This commit is contained in:
marcus-daily
2025-07-16 12:35:36 +01:00
committed by Marcus
parent ed8f30ec71
commit b5d1301221

View File

@@ -19,9 +19,14 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
try:
import torch
from torch import nn
import torch.nn.functional as F
from transformers import Wav2Vec2PreTrainedModel, Wav2Vec2Model, Wav2Vec2Processor, Wav2Vec2Config
from torch import nn
from transformers import (
Wav2Vec2Config,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
Wav2Vec2Processor,
)
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
@@ -54,9 +59,7 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn):
logger.debug("Loading Local Smart Turn v2 model...")
# Load the pretrained model for sequence classification
self._turn_model = Wav2Vec2ForEndpointing.from_pretrained(
smart_turn_model_path
)
self._turn_model = _Wav2Vec2ForEndpointing.from_pretrained(smart_turn_model_path)
# Load the corresponding feature extractor for preprocessing audio
self._turn_processor = Wav2Vec2Processor.from_pretrained(smart_turn_model_path)
# Set device to GPU if available, else CPU
@@ -75,7 +78,7 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn):
truncation=True,
max_length=16000 * 16, # 16 seconds at 16kHz
return_attention_mask=True,
return_tensors="pt"
return_tensors="pt",
)
# Move inputs to device
@@ -96,15 +99,14 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn):
"probability": probability,
}
class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
class _Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config):
super().__init__(config)
self.wav2vec2 = Wav2Vec2Model(config)
self.pool_attention = nn.Sequential(
nn.Linear(config.hidden_size, 256),
nn.Tanh(),
nn.Linear(256, 1)
nn.Linear(config.hidden_size, 256), nn.Tanh(), nn.Linear(256, 1)
)
self.classifier = nn.Sequential(
@@ -114,7 +116,7 @@ class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
nn.Dropout(0.1),
nn.Linear(256, 64),
nn.GELU(),
nn.Linear(64, 1)
nn.Linear(64, 1),
)
for module in self.classifier:
@@ -137,7 +139,7 @@ class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
raise ValueError("attention_mask must be provided for attention pooling")
attention_weights = attention_weights + (
(1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9
(1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9
)
attention_weights = F.softmax(attention_weights, dim=1)
@@ -178,7 +180,7 @@ class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
# Add L2 regularization for classifier layers
l2_lambda = 0.01
l2_reg = torch.tensor(0., device=logits.device)
l2_reg = torch.tensor(0.0, device=logits.device)
for param in self.classifier.parameters():
l2_reg += torch.norm(param)
loss += l2_lambda * l2_reg
@@ -187,4 +189,4 @@ class Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
return {"loss": loss, "logits": probs}
probs = torch.sigmoid(logits)
return {"logits": probs}
return {"logits": probs}