| | import numpy as np |
| | import torch.nn.functional as F |
| | from torch import nn |
| | from .model import MLPLayers |
| |
|
| |
|
class LinearProbe(nn.Module):
    """Linear-probe classification head on top of a (optionally frozen) CLAP audio encoder.

    The CLAP text branch is discarded; audio embeddings produced by
    ``audio_branch`` + ``audio_projection`` are fed into a small trainable
    head (a single linear layer, or an MLP).
    """

    # Output dimensionality of the CLAP audio projection.
    # NOTE(review): this overrides the caller-supplied ``in_ch`` — presumably
    # because the CLAP embedding dim is fixed at 512; TODO confirm and
    # consider honoring the parameter instead.
    CLAP_EMBED_DIM = 512

    def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
        """
        Args:
            model: nn.Module, a CLAP model exposing ``audio_branch``,
                ``audio_projection`` and ``text_branch`` attributes
            mlp: bool, if True, then use the MLP layer as the linear probe module
            freeze: bool, if True, then freeze all the CLAP model's layers when
                training the linear probe
            in_ch: int, the output channel from CLAP model (currently
                overridden to ``CLAP_EMBED_DIM``, see note above)
            out_ch: int, the output channel from linear probe (class_num)
            act: str or None, name of the activation applied before the loss
                function; one of None/"None", "relu", "elu", "prelu",
                "softmax", "sigmoid"

        Raises:
            ValueError: if ``act`` is not one of the recognized names.
        """
        super().__init__()
        in_ch = self.CLAP_EMBED_DIM
        self.clap_model = model
        # The text tower is unused for audio probing; drop it to free memory.
        self.clap_model.text_branch = None
        self.freeze = freeze
        if mlp:
            self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
        else:
            self.lp_layer = nn.Linear(in_ch, out_ch)

        if self.freeze:
            for param in self.clap_model.parameters():
                param.requires_grad = False

        # Map the activation name to a module; None / "None" means identity
        # (no activation before the loss).
        # Bug fix: the original never assigned ``self.act`` for act=None (the
        # default!) or for unknown names, which crashed later in ``forward``
        # with AttributeError. PReLU also used ``in_ch`` parameters although
        # it is applied to the ``out_ch``-dimensional probe output.
        if act is None or act == "None":
            self.act = None
        elif act == "relu":
            self.act = nn.ReLU()
        elif act == "elu":
            self.act = nn.ELU()
        elif act == "prelu":
            self.act = nn.PReLU(num_parameters=out_ch)
        elif act == "softmax":
            self.act = nn.Softmax(dim=-1)
        elif act == "sigmoid":
            self.act = nn.Sigmoid()
        else:
            raise ValueError(f"Unknown activation: {act!r}")

    def forward(self, x, mix_lambda=None, device=None):
        """
        Args:
            x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec
                and longer list
            mix_lambda: torch.tensor [batch], the mixup lambda
            device: optional device forwarded to the audio branch
        Returns:
            class_prob: torch.tensor [batch, class_num]
        """
        if self.freeze:
            # Keep frozen batch-norm / dropout layers in inference mode.
            self.clap_model.eval()

        x = self.clap_model.audio_projection(
            self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[
                "embedding"
            ]
        )
        out = self.lp_layer(x)
        if self.act is not None:
            out = self.act(out)
        return out
| |
|