| |
| import torch |
| import torch.nn as nn |
| from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule |
| from mmengine.model import BaseModule |
|
|
| from mmseg.models.decode_heads.psp_head import PPM |
| from mmseg.registry import MODELS |
| from ..utils import InvertedResidual, resize |
|
|
|
|
class LearningToDownsample(nn.Module):
    """Learning to downsample module.

    One 3x3 ``ConvModule`` followed by two 3x3
    ``DepthwiseSeparableConvModule`` layers. All three layers use stride 2
    and padding 1.

    Args:
        in_channels (int): Number of input channels.
        dw_channels (tuple[int]): Number of output channels of the first and
            the second depthwise conv (dwconv) layers. Only the first two
            entries are read.
        out_channels (int): Number of output channels of the whole
            'learning to downsample' module.
        conv_cfg (dict | None): Config of conv layers. Applied to the first
            conv and forwarded to both depthwise separable convs.
            Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Applied to the first
            conv and forwarded as ``act_cfg`` to both depthwise separable
            convs. Default: dict(type='ReLU')
        dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
            of depthwise ConvModule. If it is 'default', it will be the same
            as `act_cfg`. Default: None.
    """

    def __init__(self,
                 in_channels,
                 dw_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 dw_act_cfg=None):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.dw_act_cfg = dw_act_cfg
        dw_channels1 = dw_channels[0]
        dw_channels2 = dw_channels[1]

        self.conv = ConvModule(
            in_channels,
            dw_channels1,
            3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        # conv_cfg / act_cfg are forwarded explicitly so that the documented
        # "config of conv / activation layers" also reaches these layers,
        # instead of silently falling back to the wrapper's own defaults.
        self.dsconv1 = DepthwiseSeparableConvModule(
            dw_channels1,
            dw_channels2,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            dw_act_cfg=self.dw_act_cfg)

        self.dsconv2 = DepthwiseSeparableConvModule(
            dw_channels2,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            dw_act_cfg=self.dw_act_cfg)

    def forward(self, x):
        """Apply the first conv, then the two depthwise separable convs."""
        x = self.conv(x)
        x = self.dsconv1(x)
        x = self.dsconv2(x)
        return x
|
|
|
|
class GlobalFeatureExtractor(nn.Module):
    """Global feature extractor module.

    Three groups of Inverted Residual modules, followed by a PPM whose
    outputs are concatenated with the last group's output and fused by a
    3x3 ``ConvModule``.

    Args:
        in_channels (int): Number of input channels of the GFE module.
            Default: 64
        block_channels (tuple[int]): Tuple of ints. Each int specifies the
            number of output channels of each Inverted Residual module.
            Default: (64, 96, 128)
        out_channels(int): Number of output channels of the GFE module.
            Default: 128
        expand_ratio (int): Adjusts number of channels of the hidden layer
            in InvertedResidual by this amount.
            Default: 6
        num_blocks (tuple[int]): Tuple of ints. Each int specifies the
            number of times each Inverted Residual module is repeated.
            The repeated Inverted Residual modules are called a 'group'.
            Default: (3, 3, 3)
        strides (tuple[int]): Tuple of ints. Each int specifies
            the downsampling factor of each 'group'.
            Default: (2, 2, 1)
        pool_scales (tuple[int]): Tuple of ints. Each int specifies
            the parameter required in 'global average pooling' within PPM.
            Any number of scales is supported.
            Default: (1, 2, 3, 6)
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False

    Raises:
        AssertionError: If ``block_channels``, ``num_blocks`` or ``strides``
            does not have exactly 3 entries.
    """

    def __init__(self,
                 in_channels=64,
                 block_channels=(64, 96, 128),
                 out_channels=128,
                 expand_ratio=6,
                 num_blocks=(3, 3, 3),
                 strides=(2, 2, 1),
                 pool_scales=(1, 2, 3, 6),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        # Explicit check (instead of ``assert``) so it survives ``python -O``;
        # AssertionError is kept so existing callers see the same type.
        if not (len(block_channels) == len(num_blocks) == len(strides) == 3):
            raise AssertionError(
                'block_channels, num_blocks and strides must each have '
                'exactly 3 entries, got {}, {} and {}'.format(
                    len(block_channels), len(num_blocks), len(strides)))
        self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
                                            num_blocks[0], strides[0],
                                            expand_ratio)
        self.bottleneck2 = self._make_layer(block_channels[0],
                                            block_channels[1], num_blocks[1],
                                            strides[1], expand_ratio)
        self.bottleneck3 = self._make_layer(block_channels[1],
                                            block_channels[2], num_blocks[2],
                                            strides[2], expand_ratio)
        ppm_channels = block_channels[2] // 4
        self.ppm = PPM(
            pool_scales,
            block_channels[2],
            ppm_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=align_corners)

        # forward() concatenates the bottleneck3 output with the PPM outputs.
        # Assumes PPM yields one ``ppm_channels``-channel map per pool scale,
        # so the width follows len(pool_scales) instead of assuming 4 scales.
        # For the defaults this equals block_channels[2] * 2.
        fused_channels = block_channels[2] + len(pool_scales) * ppm_channels
        self.out = ConvModule(
            fused_channels,
            out_channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _make_layer(self,
                    in_channels,
                    out_channels,
                    blocks,
                    stride=1,
                    expand_ratio=6):
        """Build one 'group' of InvertedResidual modules.

        The first module maps ``in_channels`` to ``out_channels`` with
        ``stride``. Each of the remaining ``blocks - 1`` modules maps
        ``out_channels`` to ``out_channels`` with stride 1. At least one
        module is always built.

        Returns:
            nn.Sequential: The modules in order.
        """
        layers = [
            InvertedResidual(
                in_channels,
                out_channels,
                stride,
                expand_ratio,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        ]
        layers.extend(
            InvertedResidual(
                out_channels,
                out_channels,
                1,
                expand_ratio,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the three groups, fuse with PPM context and project."""
        x = self.bottleneck1(x)
        x = self.bottleneck2(x)
        x = self.bottleneck3(x)
        # Concatenate along the channel dim with every PPM output.
        x = torch.cat([x, *self.ppm(x)], dim=1)
        x = self.out(x)
        return x
|
|
|
|
class FeatureFusionModule(nn.Module):
    """Feature fusion module.

    The lower-resolution input is bilinearly resized to the spatial size of
    the higher-resolution input. It then passes through a 3x3 conv with
    ``groups=out_channels`` and a 1x1 conv. The result is added to a 1x1
    projection of the higher-resolution input, and a ReLU is applied to the
    sum.

    Args:
        higher_in_channels (int): Number of input channels of the
            higher-resolution branch.
        lower_in_channels (int): Number of input channels of the
            lower-resolution branch.
        out_channels (int): Number of output channels.
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        dwconv_act_cfg (dict): Config of activation layers in 3x3 conv.
            Default: dict(type='ReLU').
        conv_act_cfg (dict): Config of activation layers in the two 1x1 conv.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
    """

    def __init__(self,
                 higher_in_channels,
                 lower_in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dwconv_act_cfg=dict(type='ReLU'),
                 conv_act_cfg=None,
                 align_corners=False):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dwconv_act_cfg = dwconv_act_cfg
        self.conv_act_cfg = conv_act_cfg
        self.align_corners = align_corners

        # Every conv below shares the same conv / norm configuration.
        shared = dict(conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg)

        # Lower-resolution branch: grouped 3x3 conv, then 1x1 conv.
        self.dwconv = ConvModule(
            lower_in_channels,
            out_channels,
            3,
            padding=1,
            groups=out_channels,
            act_cfg=self.dwconv_act_cfg,
            **shared)
        self.conv_lower_res = ConvModule(
            out_channels, out_channels, 1, act_cfg=self.conv_act_cfg, **shared)

        # Higher-resolution branch: a single 1x1 projection.
        self.conv_higher_res = ConvModule(
            higher_in_channels,
            out_channels,
            1,
            act_cfg=self.conv_act_cfg,
            **shared)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, higher_res_feature, lower_res_feature):
        """Fuse the two branches and return ReLU(high + low)."""
        target_size = higher_res_feature.size()[2:]
        upsampled = resize(
            lower_res_feature,
            size=target_size,
            mode='bilinear',
            align_corners=self.align_corners)
        low_branch = self.conv_lower_res(self.dwconv(upsampled))
        high_branch = self.conv_higher_res(higher_res_feature)
        return self.relu(high_branch + low_branch)
|
|
|
|
@MODELS.register_module()
class FastSCNN(BaseModule):
    """Fast-SCNN Backbone.

    This backbone is the implementation of `Fast-SCNN: Fast Semantic
    Segmentation Network <https://arxiv.org/abs/1902.04502>`_.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        downsample_dw_channels (tuple[int]): Number of output channels after
            the first conv layer & the second conv layer in
            Learning-To-Downsample (LTD) module.
            Default: (32, 48).
        global_in_channels (int): Number of input channels of
            Global Feature Extractor(GFE).
            Equal to number of output channels of LTD.
            Default: 64.
        global_block_channels (tuple[int]): Tuple of integers that describe
            the output channels for each of the MobileNet-v2 bottleneck
            residual blocks in GFE.
            Default: (64, 96, 128).
        global_block_strides (tuple[int]): Tuple of integers
            that describe the strides (downsampling factors) for each of the
            MobileNet-v2 bottleneck residual blocks in GFE.
            Default: (2, 2, 1).
        global_out_channels (int): Number of output channels of GFE.
            Default: 128.
        higher_in_channels (int): Number of input channels of the higher
            resolution branch in FFM.
            Must equal global_in_channels.
            Default: 64.
        lower_in_channels (int): Number of input channels of the lower
            resolution branch in FFM.
            Must equal global_out_channels.
            Default: 128.
        fusion_out_channels (int): Number of output channels of FFM.
            Default: 128.
        out_indices (tuple): Tuple of indices of list
            [higher_res_features, lower_res_features, fusion_output].
            Often set to (0,1,2) to enable aux. heads.
            Default: (0, 1, 2).
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False
        dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
            of depthwise ConvModule. Forwarded to the LTD module.
            Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            When None, Kaiming init for Conv2d and constant 1 for
            _BatchNorm / GroupNorm layers are used. Default: None

    Raises:
        AssertionError: If ``global_in_channels != higher_in_channels`` or
            ``global_out_channels != lower_in_channels``.
    """

    def __init__(self,
                 in_channels=3,
                 downsample_dw_channels=(32, 48),
                 global_in_channels=64,
                 global_block_channels=(64, 96, 128),
                 global_block_strides=(2, 2, 1),
                 global_out_channels=128,
                 higher_in_channels=64,
                 lower_in_channels=128,
                 fusion_out_channels=128,
                 out_indices=(0, 1, 2),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False,
                 dw_act_cfg=None,
                 init_cfg=None):

        super().__init__(init_cfg)

        if init_cfg is None:
            self.init_cfg = [
                dict(type='Kaiming', layer='Conv2d'),
                dict(
                    type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
            ]

        # LTD output feeds both GFE and the FFM higher-res branch, and the
        # GFE output feeds the FFM lower-res branch, so the channel counts
        # must agree. Implicit literal concatenation keeps the messages free
        # of the stray indentation a backslash continuation would embed.
        if global_in_channels != higher_in_channels:
            raise AssertionError('Global Input Channels must be the same '
                                 'with Higher Input Channels!')
        if global_out_channels != lower_in_channels:
            raise AssertionError('Global Output Channels must be the same '
                                 'with Lower Input Channels!')

        self.in_channels = in_channels
        self.downsample_dw_channels1 = downsample_dw_channels[0]
        self.downsample_dw_channels2 = downsample_dw_channels[1]
        self.global_in_channels = global_in_channels
        self.global_block_channels = global_block_channels
        self.global_block_strides = global_block_strides
        self.global_out_channels = global_out_channels
        self.higher_in_channels = higher_in_channels
        self.lower_in_channels = lower_in_channels
        self.fusion_out_channels = fusion_out_channels
        self.out_indices = out_indices
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners
        self.dw_act_cfg = dw_act_cfg
        self.learning_to_downsample = LearningToDownsample(
            in_channels,
            downsample_dw_channels,
            global_in_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            dw_act_cfg=dw_act_cfg)
        self.global_feature_extractor = GlobalFeatureExtractor(
            global_in_channels,
            global_block_channels,
            global_out_channels,
            strides=self.global_block_strides,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.feature_fusion = FeatureFusionModule(
            higher_in_channels,
            lower_in_channels,
            fusion_out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dwconv_act_cfg=self.act_cfg,
            align_corners=self.align_corners)

    def forward(self, x):
        """Return the outputs selected by ``out_indices`` as a tuple.

        The candidates are [LTD output, GFE output, FFM output], in that
        order.
        """
        higher_res_features = self.learning_to_downsample(x)
        lower_res_features = self.global_feature_extractor(higher_res_features)
        fusion_output = self.feature_fusion(higher_res_features,
                                            lower_res_features)

        outs = [higher_res_features, lower_res_features, fusion_output]
        outs = [outs[i] for i in self.out_indices]
        return tuple(outs)
|
|