| import numpy as np |
| import torch.nn.functional as F |
| from torch import nn |
| from .model import MLPLayers |
|
|
|
|
class LinearProbe(nn.Module):
    """Linear (or MLP) probe trained on top of a CLAP model's projected audio embedding."""

    def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
        """
        Args:
            model: nn.Module, the CLAP model; must expose ``audio_branch`` and
                ``audio_projection``. Its ``text_branch`` is set to None because
                the probe only uses the audio path.
            mlp: bool, if True, use an MLP (``MLPLayers``) as the linear probe
                module; otherwise use a single ``nn.Linear``.
            freeze: bool, if True, freeze all the CLAP model's parameters when
                training the linear probe.
            in_ch: int, width of the CLAP audio projection output, i.e. the
                probe's input width.
            out_ch: int, the output channel of the linear probe (class_num).
            act: str or None, name of the activation applied to the probe output
                before the loss function: None / "None" (no activation), "relu",
                "elu", "prelu", "softmax" or "sigmoid".

        Raises:
            ValueError: if ``act`` is not one of the supported names.
        """
        super().__init__()
        self.clap_model = model
        # forward() only runs the audio branch, so the text branch is dropped.
        self.clap_model.text_branch = None
        self.freeze = freeze
        if mlp:
            self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
        else:
            self.lp_layer = nn.Linear(in_ch, out_ch)

        if self.freeze:
            for param in self.clap_model.parameters():
                param.requires_grad = False

        # The activation acts on the probe output, which has out_ch channels.
        self.act = self._build_activation(act, out_ch)

    @staticmethod
    def _build_activation(act, num_channels):
        """Return the activation module named by ``act``, or None for no activation.

        ``num_channels`` is the size of the last dimension the activation sees;
        it is only used to size PReLU's per-channel slopes.
        """
        # Accept both the real None (the default) and the string "None"
        # (as produced by string-valued CLI options).
        if act is None or act == "None":
            return None
        factories = {
            "relu": nn.ReLU,
            "elu": nn.ELU,
            "prelu": lambda: nn.PReLU(num_parameters=num_channels),
            "softmax": lambda: nn.Softmax(dim=-1),
            "sigmoid": nn.Sigmoid,
        }
        try:
            factory = factories[act]
        except KeyError:
            raise ValueError(
                f"Unsupported activation {act!r}; expected None, 'None' or one of "
                f"{sorted(factories)}"
            ) from None
        return factory()

    def forward(self, x, mix_lambda=None, device=None):
        """
        Args:
            x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
            mix_lambda: torch.tensor [batch], the mixup lambda (passed to the
                audio branch as ``mixup_lambda``)
            device: passed through unchanged to the CLAP audio branch
        Returns:
            class_prob: torch.tensor [batch, class_num]; class scores, or
                probabilities when ``act`` is "softmax"/"sigmoid"
        """
        if self.freeze:
            # Calling .train() on the probe would also put the backbone back in
            # train mode, so re-assert eval mode for the frozen backbone each call.
            self.clap_model.eval()

        audio_out = self.clap_model.audio_branch(
            x, mixup_lambda=mix_lambda, device=device
        )
        embedding = self.clap_model.audio_projection(audio_out["embedding"])
        out = self.lp_layer(embedding)
        if self.act is not None:
            out = self.act(out)
        return out
|
|