Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| from torchvision.models import ( | |
| efficientnet_v2_s, | |
| mobilenet_v3_large, | |
| resnet101, | |
| swin_v2_b, | |
| ) | |
| import math | |
# Number of stages the gradual-unfreezing schedule is split into.
NUM_GRADUAL_UNFREEZING_STAGES: int = 5

# Seed handed to torch.manual_seed when a classification head is built.
SEED: int = 123

# Maps the activation name used in a config to its module class.
ACT_FUNCS: dict[str, type[nn.Module]] = dict(
    relu=nn.ReLU,
    tanh=nn.Tanh,  # Tanh is not used
)
def classification_head(
    in_features: int,
    config: dict,
    flatten=False,
    num_classes: int = 7,
) -> nn.Sequential:
    """Build the classification head that replaces a backbone's classifier.

    The head is ``Linear -> LayerNorm -> activation -> Dropout -> Linear``,
    optionally preceded by ``nn.Flatten``.

    Args:
        in_features: Width of the feature vector fed into the head.
        config: Must provide ``"units"`` (hidden width), ``"activation"``
            (a key of ``ACT_FUNCS``, also passed as the ``nonlinearity``
            name to ``nn.init.kaiming_uniform_``) and ``"dropout"``
            (dropout probability).
        flatten: When true, an ``nn.Flatten`` layer is placed in front of
            the head.
        num_classes: Number of output logits. Defaults to 7, the value that
            was previously hard-coded.

    Returns:
        The assembled ``nn.Sequential`` head.

    Note:
        This reseeds torch's *global* RNG with ``SEED`` on every call so the
        head is initialised identically each time.
    """
    torch.manual_seed(SEED)

    # No bias on the first projection: the LayerNorm that follows has its
    # own learnable shift.
    first_linear = nn.Linear(in_features, config["units"], bias=False)
    nn.init.kaiming_uniform_(first_linear.weight, nonlinearity=config["activation"])

    # Layers are created in the same order as before so that the seeded RNG
    # is consumed identically and the initial weights do not change.
    layers: list[nn.Module] = [nn.Flatten()] if flatten else []
    layers += [
        first_linear,
        nn.LayerNorm(config["units"]),
        ACT_FUNCS[config["activation"]](),
        nn.Dropout(config["dropout"]),
        nn.Linear(config["units"], num_classes),
    ]
    return nn.Sequential(*layers)
class PretrainedModel(nn.Module):
    """Base class for backbones that are fine-tuned with gradual unfreezing.

    Subclasses are expected to assign ``self.model`` (the wrapped network)
    and to fill ``layers_to_unfreeze`` and ``grad_cam_layer``; this base
    class only initialises those two lists as empty.
    """

    def __init__(self, config):
        """Read ``config["model_type"]`` and set up the unfreezing state."""
        super().__init__()
        self.unfreezing_stage = 0
        # The layers in forwarding order
        self.layers_to_unfreeze: list[nn.Module] = []
        self.model_type: str = config["model_type"]
        self.grad_cam_layer: list[nn.Module] = []

    def set_head_trainable(self):
        """
        Requires overriding if the classification head is not called
        "model.classifier"
        """
        self.model.classifier.requires_grad_(True)

    def inc_grad_unfreezing(self):
        """
        Increments the gradual unfreezing process by unfreezing
        the next 100% / NUM_GRADUAL_UNFREEZING_STAGES layers
        """
        if self.unfreezing_stage <= NUM_GRADUAL_UNFREEZING_STAGES:
            self.unfreezing_stage += 1
        self.set_unfreezing_stage(self.unfreezing_stage)

    def set_unfreezing_stage(self, unfreezing_stage: int):
        """Freeze/unfreeze parameters to match ``unfreezing_stage``.

        * A stage above ``NUM_GRADUAL_UNFREEZING_STAGES`` is clamped to that
          maximum and makes *every* parameter of this module trainable.
        * Otherwise every parameter of this module is frozen first, and then
          the last ``ceil(stage / NUM_GRADUAL_UNFREEZING_STAGES * n)`` entries
          of ``layers_to_unfreeze`` (``n`` being its length) are unfrozen.
          Stage 0 (or a negative stage) therefore unfreezes no listed layer.

        The freeze step covers the whole module, so anything outside
        ``layers_to_unfreeze`` stays frozen until it is re-enabled
        explicitly, e.g. through ``set_head_trainable``.
        """
        if unfreezing_stage > NUM_GRADUAL_UNFREEZING_STAGES:
            self.unfreezing_stage = NUM_GRADUAL_UNFREEZING_STAGES
            self.requires_grad_(True)
            return

        self.unfreezing_stage = unfreezing_stage
        # Make sure all layers are untrainable before
        # setting the trainable layers to be trainable
        self.requires_grad_(False)

        layer_index = math.ceil(
            self.unfreezing_stage
            * len(self.layers_to_unfreeze)
            / NUM_GRADUAL_UNFREEZING_STAGES
        )
        # Guard the zero case: ``seq[-0:]`` is the *whole* sequence, which
        # would unfreeze every layer at stage 0 instead of none.
        if layer_index <= 0:
            return
        for module in self.layers_to_unfreeze[-layer_index:]:
            module.requires_grad_(True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped ``self.model``."""
        return self.model(x)
class EfficientNet(PretrainedModel):
    """``efficientnet_v2_s`` backbone fitted with the project's head.

    The backbone is built without a weights argument; its ``classifier`` is
    replaced by ``classification_head``.
    """

    def __init__(self, config: dict):
        super().__init__(config)
        backbone = efficientnet_v2_s()
        self.model = backbone

        # The stock classifier's second entry tells us the feature width.
        feature_width = backbone.classifier[1].in_features
        backbone.classifier = classification_head(feature_width, config)

        # Every stage of ``features``, kept in forward order.
        self.layers_to_unfreeze = list(backbone.features)
        # Last sub-module of the last feature stage.
        self.grad_cam_layer = [backbone.features[-1][-1]]
class MobileNet(PretrainedModel):
    """
    MobileNet V3 or V4, customized for our transfer learning.
    Only ``version="v3"`` is implemented here; any other value raises
    ``NotImplementedError``.
    V4 paper:
    https://arxiv.org/abs/2404.10518
    """

    def __init__(self, config: dict, version: str = "v3"):
        super().__init__(config)
        # MBNetV4 is in a MBNetV3 object for some reason
        if version != "v3":
            raise NotImplementedError()

        backbone = mobilenet_v3_large()
        self.model = backbone
        # Input width of the stock classifier's first layer.
        feature_width = backbone.classifier[0].in_features
        # Every stage of ``features``, kept in forward order.
        self.layers_to_unfreeze = list(backbone.features)
        # Last sub-module of the last feature stage.
        self.grad_cam_layer = [backbone.features[-1][-1]]
        backbone.classifier = classification_head(feature_width, config)
class ResNet(PretrainedModel):
    """``resnet101`` backbone whose ``fc`` layer is replaced by the project's head."""

    def __init__(self, config: dict):
        super().__init__(config)
        backbone = resnet101()
        self.model = backbone

        fc_width = backbone.fc.in_features
        backbone.fc = classification_head(fc_width, config)

        # Stem and residual stages, listed in forward order.
        stage_names = ("conv1", "bn1", "layer1", "layer2", "layer3", "layer4")
        self.layers_to_unfreeze = [getattr(backbone, name) for name in stage_names]
        # Last block of the final residual stage.
        self.grad_cam_layer = [backbone.layer4[-1]]

    def set_head_trainable(self):
        """Make the head trainable; it is stored as ``model.fc`` here."""
        self.model.fc.requires_grad_(True)
class Swin(PretrainedModel):
    """``swin_v2_b`` backbone whose ``head`` is replaced by the project's head."""

    def __init__(self, config: dict):
        super().__init__(config)
        backbone = swin_v2_b()
        self.model = backbone

        head_width = backbone.head.in_features
        backbone.head = classification_head(head_width, config)

        # Feature stages in forward order, followed by the ``norm`` module.
        self.layers_to_unfreeze = [*backbone.features, backbone.norm]
        self.grad_cam_layer = [backbone.permute]

    def set_head_trainable(self):
        """Make the head trainable; it is stored as ``model.head`` here."""
        self.model.head.requires_grad_(True)