| import math, torch |
| import torch.nn as nn |
| from transformers import Wav2Vec2Model |
| from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
class SEModule(nn.Module):
    """Squeeze-and-Excitation gate for 1-D feature maps.

    Produces one multiplicative gate per channel from globally pooled
    context and rescales the input with it.
    """

    def __init__(self, channels, bottleneck=128):
        super().__init__()
        stages = [
            nn.AdaptiveAvgPool1d(1),                                    # squeeze: global context per channel
            nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),  # reduce
            nn.ReLU(),
            nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),  # expand back to `channels`
            nn.Sigmoid(),                                               # per-channel gate in (0, 1)
        ]
        self.se = nn.Sequential(*stages)

    def forward(self, input):
        """Scale each channel of `input` (batch, channels, time) by its gate."""
        gates = self.se(input)
        return input * gates
|
|
|
|
class Bottle2neck(nn.Module):
    """Res2Net-style bottleneck (1x1 -> multi-scale dilated 3x3 -> 1x1) with an SE gate.

    The channel dimension is split into `scale` groups; each branch conv sees
    its own split plus the previous branch's output, and the final split is
    passed through untouched. A residual connection wraps the whole block, so
    `inplanes` must equal `planes` for the addition to work.
    """

    def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
        super().__init__()
        width = int(math.floor(planes / scale))
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)
        # The last of the `scale` splits is an identity path, so only
        # scale - 1 branch convolutions are needed.
        self.nums = scale - 1
        # "Same" padding for the dilated conv (kernel_size/dilation are
        # effectively required despite the None defaults).
        pad = math.floor(kernel_size / 2) * dilation
        branch_convs, branch_bns = [], []
        for _ in range(self.nums):
            branch_convs.append(
                nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=pad))
            branch_bns.append(nn.BatchNorm1d(width))
        self.convs = nn.ModuleList(branch_convs)
        self.bns = nn.ModuleList(branch_bns)
        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.width = width
        self.se = SEModule(planes)

    def forward(self, x):
        residual = x
        # NOTE: activation precedes batch-norm throughout; this mirrors the
        # original ECAPA-TDNN reference implementation.
        out = self.bn1(self.relu(self.conv1(x)))

        splits = torch.split(out, self.width, dim=1)
        branch_outs = []
        carry = None
        for conv, bn, chunk in zip(self.convs, self.bns, splits):
            # Each branch also receives the previous branch's output.
            carry = chunk if carry is None else carry + chunk
            carry = bn(self.relu(conv(carry)))
            branch_outs.append(carry)
        branch_outs.append(splits[self.nums])  # last split: identity path
        out = torch.cat(branch_outs, dim=1)

        out = self.bn3(self.relu(self.conv3(out)))

        out = self.se(out)
        out = out + residual
        return out
|
|
|
|
class ECAPA_TDNN(nn.Module):
    """ECAPA-TDNN classifier head with attentive statistics pooling.

    Consumes 128-dim frame features (batch, time, 128) and returns 2-class
    logits (batch, 2).
    """

    def __init__(self, C):
        super().__init__()
        # Stem: project the 128-dim input features to C channels.
        self.conv1 = nn.Conv1d(128, C, kernel_size=5, stride=1, padding=2)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(C)
        # Four SE-Res2Net blocks with increasing dilation.
        self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8)
        self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8)
        self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8)
        self.layer4 = Bottle2neck(C, C, kernel_size=3, dilation=5, scale=8)
        # Aggregate the concatenated block outputs (4*C channels) to 1536.
        self.layer5 = nn.Conv1d(4 * C, 1536, kernel_size=1)
        # Attention over [features; mean; std] -> per-frame weights
        # (3 * 1536 = 4608 input channels, softmax over time).
        self.attention = nn.Sequential(
            nn.Conv1d(4608, 256, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Tanh(),
            nn.Conv1d(256, 1536, kernel_size=1),
            nn.Softmax(dim=2),
        )
        self.bn5 = nn.BatchNorm1d(3072)  # normalizes [mu; sigma] (2 * 1536)
        self.fc6 = nn.Linear(3072, 2)

    def forward(self, x):
        # (batch, time, features) -> (batch, features, time) for Conv1d.
        x = x.transpose(1, 2)
        # Activation precedes batch-norm, matching the reference implementation.
        x = self.bn1(self.relu(self.conv1(x)))

        # Dense skip pattern: each block sees the sum of the stem output and
        # every earlier block's output.
        y1 = self.layer1(x)
        y2 = self.layer2(x + y1)
        y3 = self.layer3(x + y1 + y2)
        y4 = self.layer4(x + y1 + y2 + y3)

        x = self.relu(self.layer5(torch.cat((y1, y2, y3, y4), dim=1)))

        t = x.size()[-1]

        # Broadcast global mean/std across time so the attention module can
        # condition per-frame weights on utterance-level context.
        mean = torch.mean(x, dim=2, keepdim=True)
        std = torch.sqrt(torch.var(x, dim=2, keepdim=True).clamp(min=1e-4))
        global_x = torch.cat((x, mean.repeat(1, 1, t), std.repeat(1, 1, t)), dim=1)

        w = self.attention(global_x)

        # Attentive statistics: weighted mean and weighted std over time
        # (clamp keeps the variance non-negative before the sqrt).
        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-4))

        x = torch.cat((mu, sg), 1)
        x = self.bn5(x)
        x = self.fc6(x)

        return x
|
|
|
|
class Wav2Vec2Encoder(nn.Module):
    """SSL encoder based on Hugging Face's Wav2Vec2 model.

    Wraps ``Wav2Vec2Model`` and returns its last hidden state; SpecAugment
    masking is disabled at load time so feature extraction is deterministic.
    """

    def __init__(self,
                 model_name_or_path: str = "facebook/wav2vec2-base-960h",
                 output_attentions: bool = False,
                 output_hidden_states: bool = False,
                 normalize_waveform: bool = False):
        """Initialize the Wav2Vec2 encoder.

        Args:
            model_name_or_path: HuggingFace model name or path to local model.
            output_attentions: Whether to request attention maps from the model.
            output_hidden_states: Whether to request all hidden states from the model.
            normalize_waveform: Whether to peak-normalize the waveform input.
        """
        super().__init__()

        self.model_name_or_path = model_name_or_path
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        self.normalize_waveform = normalize_waveform

        # NOTE(review): recent `transformers` releases reject
        # `gradient_checkpointing` as a from_pretrained kwarg — confirm the
        # pinned transformers version still accepts it.
        self.model = Wav2Vec2Model.from_pretrained(
            model_name_or_path,
            gradient_checkpointing=False)
        # Disable SpecAugment time/feature masking and drop the (now unused)
        # mask embedding so forward passes are deterministic.
        self.model.config.apply_spec_augment = False
        self.model.masked_spec_embed = None

    def forward(self, x):
        """Forward pass through the Wav2Vec2 encoder.

        Args:
            x: Waveform tensor of shape (batch_size, num_samples) or
                (batch_size, num_samples, 1); a trailing singleton channel
                dim is squeezed off. NOTE(review): a 3-D input whose last dim
                is not 1 survives squeeze(-1) unchanged and will fail inside
                the model — confirm callers never pass multi-channel audio.

        Returns:
            The model's last hidden state, shape
            (batch_size, num_frames, hidden_size). NOTE(review): hidden_size
            depends on the checkpoint (the previous docstring claimed 1024,
            which matches xls-r-300m but not the base-960h default) — verify
            against the caller.
        """
        # Collapse an explicit channel dimension: (B, T, 1) -> (B, T).
        if x.ndim == 3:
            x = x.squeeze(-1)

        # Optional peak normalization; epsilon guards against all-zero input.
        if self.normalize_waveform:
            x = x / (torch.max(torch.abs(x), dim=1, keepdim=True)[0] + 1e-8)

        outputs = self.model(
            x,
            output_attentions=self.output_attentions,
            output_hidden_states=self.output_hidden_states,
            return_dict=True
        )

        last_hidden_state = outputs.last_hidden_state

        return last_hidden_state
|
|
|
|
class MLPBridge(nn.Module):
    """MLP that maps SSL-encoder features to the downstream head's input dimension."""

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = None,
                 dropout: float = 0.1, activation=nn.ReLU, n_layers: int = 1):
        """Initialize the MLP bridge.

        Args:
            input_dim: The input dimension from the SSL encoder.
            output_dim: The output dimension for the model.
            hidden_dim: Hidden dimension size. If None, use the average of input and output dims.
            dropout: Dropout probability to apply between layers (0 disables dropout).
            activation: Activation module instance (e.g. ``nn.SELU()``) or module
                class (e.g. ``nn.ReLU``); a class is instantiated with no arguments.
            n_layers: Number of Linear+Activation+Dropout blocks before the final projection.

        Raises:
            TypeError: If ``activation`` is neither an ``nn.Module`` instance nor
                a class whose no-arg construction yields one.
        """
        super().__init__()

        if hidden_dim is None:
            hidden_dim = (input_dim + output_dim) // 2

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        # Bug fix: the original appended the raw activation *class* to
        # nn.Sequential, which raised a TypeError whenever the default
        # (nn.ReLU, a class) was used. Accept both forms: instantiate a
        # class, reuse an instance as-is (original behavior for instances).
        act_fn = activation() if isinstance(activation, type) else activation
        if not isinstance(act_fn, nn.Module):
            # Explicit raise instead of `assert`, which is stripped under -O.
            raise TypeError(
                "activation must be an nn.Module instance or class, got "
                f"{type(activation).__name__}")

        layers = []
        for i in range(n_layers):
            in_dim = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_dim, hidden_dim))
            # Sharing one stateless activation module across layers is safe.
            layers.append(act_fn)
            layers.append(nn.Dropout(dropout) if dropout > 0 else nn.Identity())

        # Final projection to the requested output dimension.
        layers.append(nn.Linear(hidden_dim, output_dim))
        layers.append(nn.Dropout(dropout) if dropout > 0 else nn.Identity())

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Forward pass through the bridge.

        Args:
            x: The input tensor from the SSL encoder (..., input_dim).

        Returns:
            The transformed tensor with last dimension ``output_dim``.
        """
        return self.mlp(x)
|
|
|
|
class Spectra0Model(nn.Module, PyTorchModelHubMixin):
    """Two-class audio classifier: Wav2Vec2 features -> MLP bridge -> ECAPA-TDNN head.

    NOTE(review): what the two logits represent (e.g. bonafide vs spoof) is
    not visible in this file; ``classify`` treats logit index 1 as positive.
    """

    def __init__(self, **kwargs):
        # **kwargs is accepted (PyTorchModelHubMixin forwards config values)
        # but currently ignored — every submodule is hard-wired.
        super().__init__()
        self.ssl_encoder = Wav2Vec2Encoder("facebook/wav2vec2-xls-r-300m")
        # Bridge assumes the encoder emits 1024-dim features — TODO confirm
        # for the xls-r-300m checkpoint.
        self.bridge = MLPBridge(1024, 128, hidden_dim=128, activation=nn.SELU())
        self.ecapa_tdnn = ECAPA_TDNN(128)

    def forward(self, x):
        # Waveform batch -> frame features -> 128-dim frames -> (batch, 2) logits.
        x = self.ssl_encoder(x)
        x = self.bridge(x)
        x = self.ecapa_tdnn(x)
        return x

    @torch.inference_mode()
    def classify(self, x, threshold: float = -1.0625009):
        # Threshold the positive-class logit; returns 1.0 or 0.0.
        # NOTE(review): the default threshold is an undocumented magic number
        # (presumably from a calibration run) — confirm its provenance.
        # NOTE(review): .item() assumes batch_size == 1; larger batches raise.
        x = self.forward(x)[:, 1]
        x = (x > threshold).float()
        return x.item()
|
|
|
|
| |
| |
# Snake-case alias for Spectra0Model — NOTE(review): presumably kept so
# existing imports/registry lookups keep working; confirm before removing.
spectra_0 = Spectra0Model
|
|