| | """ |
| | This code was mostly taken from backboned-unet by mkisantal: |
| | https://github.com/mkisantal/backboned-unet/blob/master/backboned_unet/unet.py |
| | """ |
| | import torch |
| | import torch.nn as nn |
| | from torchvision import models |
| | from torch.nn import functional as F |
| |
|
| | import torch.nn as nn |
| | import torch |
| | from torchvision import models |
| |
|
| |
|
class AdaptiveConcatPool2d(nn.Module):
    """
    Apply adaptive max pooling and adaptive average pooling, then concatenate them.

    Adapted from the fastai library:
    https://github.com/fastai/fastai/blob/master/fastai/layers.py#L176
    """

    def __init__(self, sz=None):
        """Pool to spatial output size ``sz`` (``1`` when ``sz`` is falsy).

        For batched ``(N, C, H, W)`` input the result has ``2 * C`` channels.
        """
        super().__init__()
        if not sz:
            sz = 1
        self.output_size = sz
        self.ap = nn.AdaptiveAvgPool2d(sz)
        self.mp = nn.AdaptiveMaxPool2d(sz)

    def forward(self, x):
        """Return ``cat([maxpool(x), avgpool(x)], dim=1)``: max-pooled features come first."""
        max_part = self.mp(x)
        avg_part = self.ap(x)
        return torch.cat([max_part, avg_part], 1)
| |
|
| |
|
class MyNorm(nn.Module):
    """Instance normalisation with a learnable per-channel affine transform.

    Wraps ``nn.InstanceNorm2d`` (``affine=True``, no running statistics) behind a
    constructor that takes only the channel count, so the class itself can be
    passed as a ``norm_layer`` factory.
    """

    def __init__(self, num_channels):
        super(MyNorm, self).__init__()
        settings = dict(eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        # Attribute name "norm" determines the state-dict keys ("norm.weight", "norm.bias").
        self.norm = nn.InstanceNorm2d(num_channels, **settings)

    def forward(self, x):
        return self.norm(x)
| |
|
| |
|
def resnet_fastai(model, pretrained, url, replace_first_layer=None, replace_maxpool_layer=None, progress=True, map_location=None, **kwargs):
    """Build a headless ResNet body, optionally loading fastai-style pretrained weights.

    ``model`` is always constructed with ``pretrained=False``. Its last two
    children (presumably avgpool and fc for torchvision ResNets) are dropped,
    and the remaining children are wrapped in an ``nn.Sequential``.

    Args:
        model: ResNet constructor, e.g. ``torchvision.models.resnet18``.
        pretrained: if true, download the state dict at ``url`` and load it
            non-strictly into the body.
        url: checkpoint location passed to ``torch.hub.load_state_dict_from_url``.
        replace_first_layer: optional module used instead of the first child.
            When given, the checkpoint's weights for that position are discarded.
        replace_maxpool_layer: optional module assigned to ``maxpool`` before the
            body is assembled.
        progress: forwarded to ``torch.hub.load_state_dict_from_url``.
        map_location: forwarded to ``torch.hub.load_state_dict_from_url``.
        **kwargs: forwarded to ``model`` (e.g. ``norm_layer``).

    Returns:
        ``nn.Sequential`` whose children are named by position ('0', '1', ...).
    """
    cut = -2  # drop the final two children (the classification tail)
    net = model(pretrained=False, **kwargs)
    if replace_maxpool_layer is not None:
        net.maxpool = replace_maxpool_layer

    layers = list(net.children())[:cut]
    if replace_first_layer is not None:
        layers[0] = replace_first_layer
    body = nn.Sequential(*layers)

    if pretrained:
        state = torch.hub.load_state_dict_from_url(url, progress=progress, map_location=map_location)
        if replace_first_layer is not None:
            # The replacement first layer has a different shape than the
            # checkpoint's, so drop those weights. Delete in place (rather than
            # rebuilding the dict) so any `_metadata` attached to the checkpoint
            # mapping is kept for load_state_dict.
            stale_keys = [key for key in state if key.startswith("0.0.")]
            for key in stale_keys:
                del state[key]
        # Checkpoint keys are prefixed with an extra "0." (the first child of the
        # original model), so nest the body once to line the names up.
        # strict=False: keys missing from, or extra in, the checkpoint are skipped.
        nn.Sequential(body).load_state_dict(state, strict=False)
    return body
| |
|
| |
|
def get_backbone(name, pretrained=True, map_location=None):
    """Load an encoder backbone and name its skip-connection taps and output.

    Args:
        name: backbone identifier. Supported: 'resnet18', 'resnet18-4' (first conv
            takes 4 input channels), 'resnet34', 'resnet50', 'resnet101',
            'resnet152', 'vgg16', 'vgg19', 'resnet18_danbo-4', 'resnet50_danbo',
            'densenet121', 'densenet161', 'densenet169', 'densenet201'.
        pretrained: whether to load pretrained weights. Ignored for 'resnet50',
            which is always built with ``pretrained=False``.
        map_location: forwarded to the checkpoint loader of the '_danbo' backbones.

    Returns:
        (backbone, feature_names, backbone_output): the encoder module, the child
        names whose activations serve as skip connections (shallowest first;
        ``None`` marks a stage without a skip tensor), and the name of the child
        whose output is the encoder output.

    Raises:
        NotImplementedError: if ``name`` is not a supported backbone.
    """
    # Created unconditionally and before any backbone, so the RNG draws used to
    # initialise them always happen in the same order.
    first_layer_for_4chn = nn.Conv2d(
        4, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Strided conv used instead of max-pooling in some ResNet variants below.
    max_pool_layer_replace = nn.Conv2d(
        64, 64, kernel_size=3, stride=2, padding=1, bias=False)

    if name == 'resnet18':
        backbone = models.resnet18(pretrained=pretrained)
    elif name == 'resnet18-4':
        backbone = models.resnet18(pretrained=pretrained)
        backbone.conv1 = first_layer_for_4chn  # accept 4-channel input
    elif name == 'resnet34':
        backbone = models.resnet34(pretrained=pretrained)
    elif name == 'resnet50':
        # Built from scratch with MyNorm as the norm layer; presumably torchvision's
        # pretrained weights would not match this configuration.
        backbone = models.resnet50(pretrained=False, norm_layer=MyNorm)
        backbone.maxpool = max_pool_layer_replace
    elif name == 'resnet101':
        backbone = models.resnet101(pretrained=pretrained)
    elif name == 'resnet152':
        backbone = models.resnet152(pretrained=pretrained)
    elif name == 'vgg16':
        backbone = models.vgg16_bn(pretrained=pretrained).features
    elif name == 'vgg19':
        backbone = models.vgg19_bn(pretrained=pretrained).features
    elif name == 'resnet18_danbo-4':
        backbone = resnet_fastai(models.resnet18, url="https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet18-3f77756f.pth",
                                 pretrained=pretrained, map_location=map_location, norm_layer=MyNorm, replace_first_layer=first_layer_for_4chn)
    elif name == 'resnet50_danbo':
        backbone = resnet_fastai(models.resnet50, url="https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet50-13306192.pth",
                                 pretrained=pretrained, map_location=map_location, norm_layer=MyNorm, replace_maxpool_layer=max_pool_layer_replace)
    elif name == 'densenet121':
        backbone = models.densenet121(pretrained=pretrained).features
    elif name == 'densenet161':
        backbone = models.densenet161(pretrained=pretrained).features
    elif name == 'densenet169':
        backbone = models.densenet169(pretrained=pretrained).features
    elif name == 'densenet201':
        backbone = models.densenet201(pretrained=pretrained).features
    else:
        raise NotImplementedError(
            '{} backbone model is not implemented so far.'.format(name))

    if '_danbo' in name:
        # resnet_fastai returns a plain nn.Sequential, so children are named by
        # position. Assuming torchvision's child order (conv1, bn1, relu, maxpool,
        # layer1..layer4): '2' = relu, '4'-'6' = layer1-layer3, '7' = layer4.
        feature_names = [None, '2', '4', '5', '6']
        backbone_output = '7'
    elif name.startswith('resnet'):
        feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
        backbone_output = 'layer4'
    elif name == 'vgg16':
        feature_names = ['5', '12', '22', '32', '42']
        backbone_output = '43'
    elif name == 'vgg19':
        feature_names = ['5', '12', '25', '38', '51']
        backbone_output = '52'
    elif name.startswith('densenet'):
        feature_names = [None, 'relu0', 'denseblock1',
                         'denseblock2', 'denseblock3']
        backbone_output = 'denseblock4'
    elif name == 'unet_encoder':
        # NOTE(review): currently unreachable - no builder branch above accepts
        # 'unet_encoder', so such a name raises before getting here.
        feature_names = ['module1', 'module2', 'module3', 'module4']
        backbone_output = 'module5'
    else:
        raise NotImplementedError(
            '{} backbone model is not implemented so far.'.format(name))
    return backbone, feature_names, backbone_output
| |
|
| |
|
class UpsampleBlock(nn.Module):
    """One U-Net decoder stage: upsample by 2, fuse an optional skip tensor, refine with 3x3 convs.

    Parametric variant: ConvTranspose2d -> [norm] -> ReLU, concat skip,
    then Conv3x3 -> [norm] -> ReLU.
    Non-parametric variant: bilinear upsampling, concat skip, then two
    Conv3x3 -> [norm] -> ReLU stages.
    """

    def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
        """
        Args:
            ch_in: channels of the tensor being upsampled.
            ch_out: output channels; defaults to ``ch_in // 2``.
            skip_in: channels of the skip tensor concatenated in ``forward``.
            use_bn: follow each conv with ``MyNorm`` and drop the conv bias.
            parametric: learn the upsampling with a transposed convolution
                instead of using bilinear interpolation.
        """
        super(UpsampleBlock, self).__init__()

        self.parametric = parametric
        # Floor division: nn.Conv2d requires an int channel count, and
        # `ch_in / 2` would be a float.
        ch_out = ch_in // 2 if ch_out is None else ch_out

        if parametric:
            # kernel 4, stride 2, padding 1 exactly doubles H and W.
            self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
                                         stride=2, padding=1, output_padding=0, bias=(not use_bn))
            self.bn1 = MyNorm(ch_out) if use_bn else None
        else:
            self.up = None
            # The skip tensor is concatenated before conv1 in this variant.
            ch_in = ch_in + skip_in
            self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
                                   stride=1, padding=1, bias=(not use_bn))
            self.bn1 = MyNorm(ch_out) if use_bn else None

        self.relu = nn.ReLU(inplace=True)

        # In the parametric variant the skip tensor is concatenated before conv2.
        conv2_in = ch_out if not parametric else ch_out + skip_in
        self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
                               stride=1, padding=1, bias=(not use_bn))
        self.bn2 = MyNorm(ch_out) if use_bn else None

    def forward(self, x, skip_connection=None):
        """Upsample ``x`` by 2, concatenate ``skip_connection`` on dim 1 if given, and refine.

        ``skip_connection`` must match the upsampled spatial size and carry
        ``skip_in`` channels.
        """
        if self.parametric:
            x = self.up(x)
            x = self.bn1(x) if self.bn1 is not None else x
            x = self.relu(x)
        else:
            # align_corners=None resolves to False for bilinear mode; pass it explicitly.
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

        if skip_connection is not None:
            x = torch.cat([x, skip_connection], dim=1)

        if not self.parametric:
            x = self.conv1(x)
            x = self.bn1(x) if self.bn1 is not None else x
            x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x) if self.bn2 is not None else x
        x = self.relu(x)

        return x
| |
|
| |
|
class ResEncUnet(nn.Module):
    """U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones.

    The encoder comes from ``get_backbone``. Activations of the encoder children
    named in ``shortcut_features`` are used as skip connections, consumed
    deepest-first by a chain of ``UpsampleBlock`` decoders, followed by a 1x1
    output convolution.
    """

    def __init__(self,
                 backbone_name,
                 pretrained=True,
                 encoder_freeze=False,
                 classes=21,
                 decoder_filters=(512, 256, 128, 64, 32),
                 parametric_upsampling=True,
                 shortcut_features='default',
                 decoder_use_instancenorm=True,
                 map_location=None
                 ):
        """
        Args:
            backbone_name: name accepted by ``get_backbone``.
            pretrained: forwarded to ``get_backbone``.
            encoder_freeze: if true, disable gradients for all encoder parameters.
            classes: output channels of the final 1x1 convolution.
            decoder_filters: output channels of successive decoder blocks,
                truncated to the number of skip connections.
            parametric_upsampling: use transposed convolutions instead of
                bilinear upsampling in the decoder.
            shortcut_features: 'default' for the backbone's standard taps, or a
                list of encoder child names (``None`` = stage without a skip tensor).
            decoder_use_instancenorm: normalise decoder convs (``use_bn`` of
                ``UpsampleBlock``).
            map_location: forwarded to ``get_backbone``.
        """
        super(ResEncUnet, self).__init__()

        self.backbone_name = backbone_name

        self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(
            backbone_name, pretrained=pretrained, map_location=map_location)
        # Apply a custom tap list BEFORE probing channel counts; otherwise the
        # decoder would be sized for the default taps.
        if shortcut_features != 'default':
            self.shortcut_features = shortcut_features
        shortcut_chs, bb_out_chs = self.infer_skip_channels()

        # One decoder block per skip connection; block i consumes the i-th
        # deepest tap.
        self.upsample_blocks = nn.ModuleList()
        decoder_filters = decoder_filters[:len(self.shortcut_features)]
        decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])
        num_blocks = len(self.shortcut_features)
        for i, (filters_in, filters_out) in enumerate(zip(decoder_filters_in, decoder_filters)):
            self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out,
                                                      skip_in=shortcut_chs[num_blocks - i - 1],
                                                      parametric=parametric_upsampling,
                                                      use_bn=decoder_use_instancenorm))
        self.final_conv = nn.Conv2d(
            decoder_filters[-1], classes, kernel_size=(1, 1))

        if encoder_freeze:
            self.freeze_encoder()

    def freeze_encoder(self):
        """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, *inputs, ret_parser_out=True):
        """Run the encoder and, optionally, the decoder.

        Args:
            *inputs: forwarded to ``forward_backbone`` (which takes one tensor).
            ret_parser_out: if true, run the decoder and the final 1x1 conv;
                otherwise skip them and return ``None`` in their place.

        Returns:
            (out, output_feature): ``out`` is the ``final_conv`` output (or
            ``None``). ``output_feature`` holds the encoder output followed by
            the non-None skip tensors, deepest first.
        """
        x, features = self.forward_backbone(*inputs)
        output_feature = [x]
        for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks):
            skip_features = features[skip_name]
            if skip_features is not None:
                output_feature.append(skip_features)
            if ret_parser_out:
                x = upsample_block(x, skip_features)
        x = self.final_conv(x) if ret_parser_out else None
        return x, output_feature

    def forward_backbone(self, x):
        """Run the encoder up to ``self.bb_out_name``, collecting skip activations.

        Returns:
            (x, features): the encoder output and a dict mapping each reached
            name in ``self.shortcut_features`` to its activation. A ``None``
            entry in the tap list maps to ``None``.
        """
        features = {None: None} if None in self.shortcut_features else dict()
        for name, child in self.backbone.named_children():
            x = child(x)
            if name in self.shortcut_features:
                features[name] = x
            if name == self.bb_out_name:
                break
        return x, features

    def infer_skip_channels(self):
        """Probe the encoder with a zero image to get channel counts.

        Returns:
            (channels, out_channels): ``channels[i]`` is the channel count of the
            tap ``self.shortcut_features[i]`` (0 for a ``None`` entry), and
            ``out_channels`` is that of ``self.bb_out_name``.

        Raises:
            ValueError: if the encoder output or a requested tap is never reached.
        """
        in_chans = 4 if "-4" in self.backbone_name else 3
        first_param = next(self.backbone.parameters(), None)
        device = first_param.device if first_param is not None else None
        x = torch.zeros(1, in_chans, 224, 224, device=device)

        wanted = set(self.shortcut_features)
        found = {}
        out_channels = None
        was_training = self.backbone.training
        # Probe in eval mode without autograd: a train-mode pass on the dummy
        # input would update BatchNorm running statistics of pretrained encoders.
        self.backbone.eval()
        try:
            with torch.no_grad():
                for name, child in self.backbone.named_children():
                    x = child(x)
                    if name in wanted:
                        found[name] = x.shape[1]
                    if name == self.bb_out_name:
                        out_channels = x.shape[1]
                        break
        finally:
            self.backbone.train(was_training)

        if out_channels is None:
            raise ValueError(
                "encoder output {!r} not found in backbone {!r}".format(self.bb_out_name, self.backbone_name))
        missing = [n for n in self.shortcut_features if n is not None and n not in found]
        if missing:
            raise ValueError(
                "shortcut features {} not reached before {!r}".format(missing, self.bb_out_name))
        # Keep the same order as shortcut_features so the decoder can index it.
        channels = [0 if n is None else found[n] for n in self.shortcut_features]
        return channels, out_channels
| |
|