# NematodeClassifier / external_models.py
# Author: VikramR - uploaded in commit a9d56ef ("Uploaded app")
import torch
from torch import nn
from torchvision.models import (
efficientnet_v2_s,
mobilenet_v3_large,
resnet101,
swin_v2_b,
)
import math
# How many steps gradual unfreezing takes to go from "head only" to the full model.
NUM_GRADUAL_UNFREEZING_STAGES = 5
# Seed passed to torch.manual_seed before building a classification head in this module.
SEED = 123
# Activation name (as given in config["activation"]) -> activation module class.
# The same name is also passed to nn.init.kaiming_uniform_ as `nonlinearity`,
# so keys must be names that torch's init gain calculation accepts.
ACT_FUNCS = dict(
    relu=nn.ReLU,
    tanh=nn.Tanh,  # not currently used
)
def classification_head(
    in_features: int, config: dict, flatten=False, *, num_classes: int = 7
) -> nn.Sequential:
    """Build the classification head attached on top of a backbone.

    Layout: ``[Flatten] -> Linear (no bias) -> LayerNorm -> activation ->
    Dropout -> Linear(num_classes)``.

    Args:
        in_features: Size of the feature vector the backbone feeds the head.
        config: Hyperparameters. Reads ``"units"`` (hidden width),
            ``"activation"`` (a key of ``ACT_FUNCS``, also used as the
            Kaiming ``nonlinearity``) and ``"dropout"`` (dropout probability).
        flatten: If True, prepend ``nn.Flatten()`` to the head.
        num_classes: Number of output logits. Defaults to 7, the value this
            function previously hard-coded.

    Returns:
        The head as an ``nn.Sequential``.

    Note:
        ``torch.manual_seed(SEED)`` reseeds the *global* torch RNG as a side
        effect, so heads built with the same arguments get identical initial
        weights, and the caller's RNG stream is reset.
    """
    torch.manual_seed(SEED)
    hidden = config["units"]
    activation = config["activation"]
    first_linear = nn.Linear(in_features, hidden, bias=False)
    # Match the init gain to the activation that follows this layer.
    nn.init.kaiming_uniform_(first_linear.weight, nonlinearity=activation)
    head = nn.Sequential(
        first_linear,
        nn.LayerNorm(hidden),
        ACT_FUNCS[activation](),
        nn.Dropout(config["dropout"]),
        nn.Linear(hidden, num_classes),
    )
    if flatten:
        head.insert(0, nn.Flatten())
    return head
class PretrainedModel(nn.Module):
    """Base wrapper for torchvision backbones trained with gradual unfreezing.

    Subclasses are expected to set ``self.model`` (the wrapped network),
    ``self.layers_to_unfreeze`` (backbone modules in forward order) and
    ``self.grad_cam_layer``, and to override :meth:`set_head_trainable` when
    the classification head is not ``self.model.classifier``.
    """

    def __init__(self, config):
        super().__init__()
        # Current gradual-unfreezing stage; set_unfreezing_stage keeps it
        # within [0, NUM_GRADUAL_UNFREEZING_STAGES].
        self.unfreezing_stage = 0
        # Backbone modules in forward order; stages unfreeze from the end.
        self.layers_to_unfreeze: list[nn.Module] = []
        self.model_type: str = config["model_type"]
        # Module(s) presumably used by the caller as the Grad-CAM target layer.
        self.grad_cam_layer: list[nn.Module] = []

    def set_head_trainable(self):
        """
        Make the classification head trainable.

        Requires overriding if the classification head is not called
        "model.classifier"
        """
        self.model.classifier.requires_grad_(True)

    def inc_grad_unfreezing(self):
        """
        Advance gradual unfreezing by one stage, unfreezing roughly the next
        100% / NUM_GRADUAL_UNFREEZING_STAGES of ``layers_to_unfreeze``.

        Calling it again once the last stage is reached unfreezes the entire
        model (including modules not listed in ``layers_to_unfreeze``).
        """
        if self.unfreezing_stage <= NUM_GRADUAL_UNFREEZING_STAGES:
            self.unfreezing_stage += 1
        self.set_unfreezing_stage(self.unfreezing_stage)

    def set_unfreezing_stage(self, unfreezing_stage: int):
        """Set which parts of the model are trainable.

        For a stage ``k`` in ``[0, NUM_GRADUAL_UNFREEZING_STAGES]``, the last
        ``ceil(k * len(layers_to_unfreeze) / NUM_GRADUAL_UNFREEZING_STAGES)``
        entries of ``layers_to_unfreeze`` and the classification head are
        trainable, and everything else is frozen. Stage 0 therefore trains
        only the head. A negative stage is treated as 0. A stage above the
        maximum makes the whole model trainable and is stored as the maximum.
        """
        if unfreezing_stage > NUM_GRADUAL_UNFREEZING_STAGES:
            self.unfreezing_stage = NUM_GRADUAL_UNFREEZING_STAGES
            self.requires_grad_(True)
            return

        self.unfreezing_stage = max(unfreezing_stage, 0)
        # Freeze everything first so layers trainable in a previous stage
        # don't stay trainable.
        self.requires_grad_(False)
        # The head is not part of layers_to_unfreeze, so the blanket freeze
        # above would otherwise leave it frozen at every intermediate stage.
        self.set_head_trainable()

        num_layers = len(self.layers_to_unfreeze)
        num_to_unfreeze = math.ceil(
            self.unfreezing_stage * num_layers / NUM_GRADUAL_UNFREEZING_STAGES
        )
        # Slice from an explicit start index: `[-0:]` would select *every*
        # layer when nothing should be unfrozen.
        for module in self.layers_to_unfreeze[num_layers - num_to_unfreeze:]:
            module.requires_grad_(True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on ``x``."""
        return self.model(x)
class EfficientNet(PretrainedModel):
    """EfficientNetV2-S backbone with its stock classifier swapped for a custom head."""

    def __init__(self, config: dict):
        super().__init__(config)
        backbone = efficientnet_v2_s()
        self.model = backbone
        # classifier[1] is the stock head's Linear layer; reuse its input size.
        feature_dim = backbone.classifier[1].in_features
        backbone.classifier = classification_head(feature_dim, config)
        # Every top-level entry of `features`, in forward order.
        self.layers_to_unfreeze = list(backbone.features)
        # Last sub-module of the final feature stage.
        self.grad_cam_layer = [backbone.features[-1][-1]]
class MobileNet(PretrainedModel):
    """
    MobileNetV3-Large, customized for our transfer learning.

    Only ``version="v3"`` is implemented. Any other value (including "v4",
    paper: https://arxiv.org/abs/2404.10518) raises ``NotImplementedError``.
    """

    def __init__(self, config: dict, version: str = "v3"):
        super().__init__(config)
        if version != "v3":
            # NOTE(review): an earlier comment said MobileNetV4 lives inside a
            # MobileNetV3 object in some library; it was never wired up here.
            raise NotImplementedError(
                f"MobileNet version {version!r} is not supported; only 'v3' is implemented"
            )
        self.model = mobilenet_v3_large()
        # The whole stock classifier is replaced (not just its last Linear), so
        # the new head consumes the input size of the stock head's first Linear.
        in_features = self.model.classifier[0].in_features
        self.model.classifier = classification_head(in_features, config)
        # Every top-level entry of `features`, in forward order.
        self.layers_to_unfreeze = list(self.model.features)
        # Last sub-module of the final feature block.
        self.grad_cam_layer = [self.model.features[-1][-1]]
class ResNet(PretrainedModel):
    """ResNet-101 backbone whose final ``fc`` layer is replaced by a custom head."""

    def __init__(self, config: dict):
        super().__init__(config)
        self.model = resnet101()
        net = self.model
        net.fc = classification_head(net.fc.in_features, config)
        # Stem conv/bn followed by the four residual stages, in forward order.
        self.layers_to_unfreeze = [
            getattr(net, name)
            for name in ("conv1", "bn1", "layer1", "layer2", "layer3", "layer4")
        ]
        # Last block of the final residual stage.
        self.grad_cam_layer = [net.layer4[-1]]

    def set_head_trainable(self):
        """Make the replacement head (``model.fc``) trainable."""
        self.model.fc.requires_grad_(True)
class Swin(PretrainedModel):
    """Swin Transformer V2-B backbone whose ``head`` is replaced by a custom head."""

    def __init__(self, config: dict):
        super().__init__(config)
        swin = swin_v2_b()
        self.model = swin
        head_in = swin.head.in_features
        swin.head = classification_head(head_in, config)
        # Every entry of `features`, then the final norm, in forward order.
        stages = list(swin.features)
        stages.append(swin.norm)
        self.layers_to_unfreeze = stages
        self.grad_cam_layer = [swin.permute]

    def set_head_trainable(self):
        """Make the replacement head (``model.head``) trainable."""
        self.model.head.requires_grad_(True)