import torch from torch import nn from torchvision.models import ( efficientnet_v2_s, mobilenet_v3_large, resnet101, swin_v2_b, ) import math NUM_GRADUAL_UNFREEZING_STAGES = 5 SEED = 123 ACT_FUNCS = { "relu": nn.ReLU, "tanh": nn.Tanh, # Tanh is not used } def classification_head(in_features: int, config: dict, flatten=False) -> nn.Sequential: torch.manual_seed(SEED) first_linear = nn.Linear(in_features, config["units"], bias=False) nn.init.kaiming_uniform_(first_linear.weight, nonlinearity=config["activation"]) head = nn.Sequential( first_linear, nn.LayerNorm(config["units"]), ACT_FUNCS[config["activation"]](), nn.Dropout(config["dropout"]), nn.Linear(config["units"], 7), ) if flatten: head.insert(0, nn.Flatten()) return head class PretrainedModel(nn.Module): def __init__(self, config): super().__init__() self.unfreezing_stage = 0 # The layers in forwarding order self.layers_to_unfreeze: list[nn.Module] = [] self.model_type: str = config["model_type"] self.grad_cam_layer: list[nn.Module] = [] def set_head_trainable(self): """ Requires overriding if the classification head is not called "model.classifier" """ self.model.classifier.requires_grad_(True) def inc_grad_unfreezing(self): """ Increments the gradual unfreezing process by unfreezing the next 100% / NUM_GRADUAL_UNFREEZING_STAGES layers """ if self.unfreezing_stage <= NUM_GRADUAL_UNFREEZING_STAGES: self.unfreezing_stage += 1 self.set_unfreezing_stage(self.unfreezing_stage) def set_unfreezing_stage(self, unfreezing_stage: int): self.unfreezing_stage = unfreezing_stage if self.unfreezing_stage > NUM_GRADUAL_UNFREEZING_STAGES: self.unfreezing_stage = NUM_GRADUAL_UNFREEZING_STAGES self.requires_grad_(True) return else: # Make sure all layers are untrainable before # setting the trainable layers to be trainable self.requires_grad_(False) layer_index = math.ceil( self.unfreezing_stage * len(self.layers_to_unfreeze) / NUM_GRADUAL_UNFREEZING_STAGES ) for module in self.layers_to_unfreeze[-layer_index:]: module.requires_grad_(True) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x) class EfficientNet(PretrainedModel): def __init__(self, config: dict): super().__init__(config) self.model = efficientnet_v2_s() in_features = self.model.classifier[1].in_features self.model.classifier = classification_head(in_features, config) self.layers_to_unfreeze = [ self.model.features[i] for i in range(len(self.model.features)) ] self.grad_cam_layer = [self.model.features[-1][-1]] class MobileNet(PretrainedModel): """ MobileNet V3 or V4, customized for our transfer learning V4 paper: https://arxiv.org/abs/2404.10518 """ def __init__(self, config: dict, version: str = "v3"): super().__init__(config) # MBNetV4 is in a MBNetV3 object for some reason if version == "v3": self.model = mobilenet_v3_large() in_features = self.model.classifier[0].in_features self.layers_to_unfreeze = [ self.model.features[i] for i in range(len(self.model.features)) ] self.grad_cam_layer = [self.model.features[-1][-1]] else: raise NotImplementedError() self.model.classifier = classification_head(in_features, config) class ResNet(PretrainedModel): def __init__(self, config: dict): super().__init__(config) self.model = resnet101() in_features = self.model.fc.in_features self.model.fc = classification_head(in_features, config) self.layers_to_unfreeze = [ self.model.conv1, self.model.bn1, self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4, ] self.grad_cam_layer = [self.model.layer4[-1]] def set_head_trainable(self): self.model.fc.requires_grad_(True) class Swin(PretrainedModel): def __init__(self, config: dict): super().__init__(config) self.model = swin_v2_b() in_features = self.model.head.in_features self.model.head = classification_head(in_features, config) self.layers_to_unfreeze = [ self.model.features[i] for i in range(len(self.model.features)) ] + [self.model.norm] self.grad_cam_layer = [self.model.permute] def set_head_trainable(self): self.model.head.requires_grad_(True)