| | |
| | """Modified from https://github.com/MichaelFan01/STDC-Seg.""" |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from mmcv.cnn import ConvModule |
| | from mmengine.model import BaseModule, ModuleList, Sequential |
| |
|
| | from mmseg.registry import MODELS |
| | from ..utils import resize |
| | from .bisenetv1 import AttentionRefinementModule |
| |
|
| |
|
class STDCModule(BaseModule):
    """STDCModule.

    The module runs ``num_convs`` conv layers one after another. The first
    layer is a 1x1 conv to ``out_channels // 2`` channels; each following
    3x3 layer halves the width again, except the last one, which keeps the
    width of its input. The outputs of all layers are concatenated along the
    channel dimension, so for ``out_channels`` divisible by
    ``2 ** (num_convs - 1)`` the result has ``out_channels`` channels.

    With ``fusion_type='add'`` the concatenation is added to a residual
    branch (the raw input when ``stride == 1``, or a depthwise stride-2 conv
    followed by a 1x1 projection to ``out_channels`` when ``stride == 2``).
    With ``fusion_type='cat'`` no residual is added; when ``stride == 2`` the
    first layer's output is average-pooled before being concatenated.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels before scaling.
        stride (int): The number of stride for the first conv layer. Must be
            1 or 2; 2 downsamples the feature map by a factor of 2.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): The activation config for the 3x3 conv layers.
            Default: None.
        num_convs (int): Numbers of conv layers. Must be greater than 1.
            Default: 4.
        fusion_type (str): Type of fusion operation, ``'add'`` or ``'cat'``.
            Default: 'add'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 norm_cfg=None,
                 act_cfg=None,
                 num_convs=4,
                 fusion_type='add',
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert num_convs > 1
        assert fusion_type in ['add', 'cat']
        # Any stride other than 2 used to be treated silently as 1; reject
        # unsupported values instead of building a module that ignores them.
        assert stride in (1, 2), f'stride must be 1 or 2, got {stride}'
        self.stride = stride
        self.with_downsample = self.stride == 2
        self.fusion_type = fusion_type

        self.layers = ModuleList()
        # First layer: 1x1 projection to half of ``out_channels``. It is
        # built without ``act_cfg``, so ConvModule's own default activation
        # applies here.
        conv_0 = ConvModule(
            in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)

        if self.with_downsample:
            # Depthwise (groups == channels) 3x3 stride-2 conv, no activation.
            self.downsample = ConvModule(
                out_channels // 2,
                out_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=out_channels // 2,
                norm_cfg=norm_cfg,
                act_cfg=None)

            if self.fusion_type == 'add':
                # Downsampling is folded into the first layer; the residual
                # branch gets its own depthwise stride-2 conv plus a 1x1
                # projection so it matches the concatenated output.
                self.layers.append(nn.Sequential(conv_0, self.downsample))
                self.skip = Sequential(
                    ConvModule(
                        in_channels,
                        in_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        groups=in_channels,
                        norm_cfg=norm_cfg,
                        act_cfg=None),
                    ConvModule(
                        in_channels,
                        out_channels,
                        1,
                        norm_cfg=norm_cfg,
                        act_cfg=None))
            else:
                # 'cat': ``self.downsample`` is applied in ``forward_cat``
                # after the first layer, whose own output is average-pooled.
                self.layers.append(conv_0)
                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.layers.append(conv_0)

        # Remaining 3x3 layers halve the width at each step; the last layer
        # keeps its input width so that the widths of all layers add up to
        # ``out_channels`` (when it is divisible by 2 ** (num_convs - 1)).
        for i in range(1, num_convs):
            out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
            self.layers.append(
                ConvModule(
                    out_channels // 2**i,
                    out_channels // out_factor,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        """Dispatch to :meth:`forward_add` or :meth:`forward_cat`."""
        if self.fusion_type == 'add':
            return self.forward_add(inputs)
        return self.forward_cat(inputs)

    def forward_add(self, inputs):
        """Concatenate all layer outputs and add the residual branch."""
        layer_outputs = []
        # ``self.layers[0]`` starts with ``conv_0`` (a ConvModule with the
        # default conv -> norm -> act order), whose convolution allocates a
        # new tensor, so ``inputs`` stays intact for the residual below
        # without a defensive ``clone()`` and its extra activation copy.
        x = inputs
        for layer in self.layers:
            x = layer(x)
            layer_outputs.append(x)
        residual = self.skip(inputs) if self.with_downsample else inputs
        return torch.cat(layer_outputs, dim=1) + residual

    def forward_cat(self, inputs):
        """Concatenate all layer outputs without a residual addition."""
        x0 = self.layers[0](inputs)
        # In 'cat' mode the stride-2 downsampling sits between the first and
        # second layer, and the first layer's own output is pooled to match.
        x = self.downsample(x0) if self.with_downsample else x0
        layer_outputs = [self.skip(x0) if self.with_downsample else x0]
        for layer in self.layers[1:]:
            x = layer(x)
            layer_outputs.append(x)
        return torch.cat(layer_outputs, dim=1)
| |
|
| |
|
class FeatureFusionModule(BaseModule):
    """Fuse a spatial-path feature map with a context-path feature map.

    The two inputs are concatenated along the channel dimension and projected
    by a 1x1 ``ConvModule`` (``self.conv0``). A channel-attention branch
    (``self.attention``: global average pooling, a 1x1 reduction to
    ``out_channels // scale_factor`` channels, a 1x1 expansion back to
    ``out_channels`` and a sigmoid) re-weights the projected features, and
    the re-weighted features are added back onto the projection.

    This differs from the FeatureFusionModule of BiSeNetV1, which uses a
    single ConvModule in ``self.conv_atten``: here ``self.attention`` holds
    two ConvModules whose intermediate width is set by ``scale_factor``.

    Args:
        in_channels (int): The number of input channels (the sum of the
            channels of both inputs to :meth:`forward`).
        out_channels (int): The number of output channels.
        scale_factor (int): Channel reduction factor of the attention branch.
            Default: 4.
        norm_cfg (dict): Config dict for normalization layer of ``conv0``.
            Default: dict(type='BN').
        act_cfg (dict): The activation config for conv layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=4,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        mid_channels = out_channels // scale_factor

        # Submodules are created in the same order as before (conv0 first,
        # then the two attention convs) so parameter initialisation and
        # state_dict keys are unchanged.
        self.conv0 = ConvModule(
            in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        reduce_conv = ConvModule(
            out_channels,
            mid_channels,
            1,
            norm_cfg=None,
            bias=False,
            act_cfg=act_cfg)
        expand_conv = ConvModule(
            mid_channels,
            out_channels,
            1,
            norm_cfg=None,
            bias=False,
            act_cfg=None)
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)), reduce_conv, expand_conv,
            nn.Sigmoid())

    def forward(self, spatial_inputs, context_inputs):
        """Return ``conv0(cat(inputs)) * attention + conv0(cat(inputs))``."""
        stacked = torch.cat([spatial_inputs, context_inputs], dim=1)
        fused = self.conv0(stacked)
        return fused * self.attention(fused) + fused
| |
|
| |
|
@MODELS.register_module()
class STDCNet(BaseModule):
    """This backbone is the implementation of `Rethinking BiSeNet For Real-time
    Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    The network is made of two stride-2 3x3 stem convs followed by one stage
    of STDC Modules per entry of ``arch_settings[stdc_type]``. Only the
    outputs of the STDC stages are returned; the stem outputs are dropped.

    Args:
        stdc_type (str): The type of backbone structure, either
            ``'STDCNet1'`` or ``'STDCNet2'``. These denote the two main
            backbones in the paper, whose FLOPs are 813M and 1446M,
            respectively.
        in_channels (int): The num of input_channels.
        channels (tuple[int]): The output channels for each stage. Must have
            five entries: two for the stem convs, three for the STDC stages.
        bottleneck_type (str): The type of STDC Module type, the value must
            be 'add' or 'cat'.
        norm_cfg (dict): Config dict for normalization layer.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Numbers of conv layer at each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether to add a 1x1 conv layer (with
            ``max(1024, channels[-1])`` output channels) on the last output.
            Default: False.
        pretrained (str, optional): Model pretrained path. It is only stored
            on the instance by this class. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> import torch
        >>> stdc_type = 'STDCNet1'
        >>> in_channels = 3
        >>> channels = (32, 64, 256, 512, 1024)
        >>> bottleneck_type = 'cat'
        >>> norm_cfg = dict(type='BN')
        >>> act_cfg = dict(type='ReLU')
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> self = STDCNet(stdc_type, in_channels, channels,
        ...                bottleneck_type, norm_cfg, act_cfg).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 512, 64, 128])
        outputs[2].shape = torch.Size([1, 1024, 32, 64])
    """

    # Per-stage strides of the STDC Modules; a 2 starts a new resolution.
    arch_settings = {
        'STDCNet1': [(2, 1), (2, 1), (2, 1)],
        'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
    }

    def __init__(self,
                 stdc_type,
                 in_channels,
                 channels,
                 bottleneck_type,
                 norm_cfg,
                 act_cfg,
                 num_convs=4,
                 with_final_conv=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert stdc_type in self.arch_settings, \
            f'invalid structure {stdc_type} for STDCNet.'
        assert bottleneck_type in ['add', 'cat'],\
            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'

        assert len(channels) == 5,\
            f'invalid channels length {len(channels)} for STDCNet.'

        self.in_channels = in_channels
        self.channels = channels
        self.stage_strides = self.arch_settings[stdc_type]
        self.pretrained = pretrained
        # Misspelled alias kept so code reading the old attribute still works.
        self.prtrained = pretrained
        self.num_convs = num_convs
        self.with_final_conv = with_final_conv

        self.stages = ModuleList([
            ConvModule(
                self.in_channels,
                self.channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                self.channels[0],
                self.channels[1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ])
        # `self.num_shallow_features` is the number of shallow (stem) modules
        # in `STDCNet`, noted as `Stage1` and `Stage2` in the original paper.
        # Their outputs are not used by the following modules (Attention
        # Refinement Module, Feature Fusion Module), so they are cut from
        # `outs` in `forward`. See Figure 4 of the original paper.
        self.num_shallow_features = len(self.stages)

        for strides in self.stage_strides:
            idx = len(self.stages) - 1
            self.stages.append(
                self._make_stage(self.channels[idx], self.channels[idx + 1],
                                 strides, norm_cfg, act_cfg, bottleneck_type))
        # `self.stages` now holds the shallow modules followed by the STDC
        # stages: len(self.stages) ==
        # self.num_shallow_features + len(self.stage_strides).
        if self.with_final_conv:
            self.final_conv = ConvModule(
                self.channels[-1],
                max(1024, self.channels[-1]),
                1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
                    act_cfg, bottleneck_type):
        """Build one stage: an STDCModule per entry of ``strides``.

        Only the first module changes the channel count
        (``in_channels`` -> ``out_channels``); the rest keep ``out_channels``.
        """
        layers = []
        for i, stride in enumerate(strides):
            layers.append(
                STDCModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride,
                    norm_cfg,
                    act_cfg,
                    num_convs=self.num_convs,
                    fusion_type=bottleneck_type))
        return Sequential(*layers)

    def forward(self, x):
        """Run all stages and return the STDC stage outputs as a tuple.

        The shallow (stem) outputs are dropped; if ``with_final_conv`` is set,
        the last returned map is passed through ``self.final_conv``.
        """
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        if self.with_final_conv:
            outs[-1] = self.final_conv(outs[-1])
        outs = outs[self.num_shallow_features:]
        return tuple(outs)
| |
|
| |
|
@MODELS.register_module()
class STDCContextPathNet(BaseModule):
    """STDCNet with Context Path.

    The backbone returns ``outs``, a list of feature maps (three for STDCNet)
    ordered from shallow to deep, so their height and width go from big to
    small. The biggest one, ``outs[0]``, is outputted for `STDCHead`, where
    Detail Loss would be calculated by Detail Ground-truth. The deepest
    ``len(last_in_channels)`` maps are each refined by an Attention
    Refinement Module (ARM), going from deep to shallow: every step adds the
    upsampled context of the previous step and is upsampled to the next
    shallower map's size. Besides, ``outs[0]`` and the last (shallowest) ARM
    output are concatenated for the Feature Fusion Module. Then, this fusion
    feature map `feat_fuse` would be outputted for `decode_head`. More
    details please refer to Figure 4 of original paper.

    Args:
        backbone_cfg (dict): Config dict for stdc backbone.
        last_in_channels (tuple[int]): The number of channels of the last
            feature maps from the backbone, deepest first. One ARM is built
            per entry; there must be fewer entries than backbone outputs.
            Default: (1024, 512).
        out_channels (int): The channels of output feature maps.
            Default: 128.
        ffm_cfg (dict): Config dict for Feature Fusion Module. Its
            ``in_channels`` must equal the channels of ``outs[0]`` plus
            ``out_channels``, since those two maps are concatenated.
            Default: `dict(in_channels=512, out_channels=256, scale_factor=4)`.
        upsample_mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``.
        align_corners (bool, optional): align_corners argument of
            F.interpolate. It must be `None` if upsample_mode is
            ``'nearest'``. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Return:
        outputs (tuple): ``(outs[0], *arms_out, feat_fuse)``: the biggest
        backbone map (for the `STDCHead` auxiliary head), the ARM outputs
        (for auxiliary heads) and the fused map (for the decoder head).
    """

    def __init__(self,
                 backbone_cfg,
                 last_in_channels=(1024, 512),
                 out_channels=128,
                 ffm_cfg=dict(
                     in_channels=512, out_channels=256, scale_factor=4),
                 upsample_mode='nearest',
                 align_corners=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone_cfg)
        # One (ARM, 3x3 conv) pair per refined backbone map, deepest first.
        self.arms = ModuleList()
        self.convs = ModuleList()
        for channels in last_in_channels:
            self.arms.append(AttentionRefinementModule(channels, out_channels))
            self.convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg))
        # 1x1 conv on the globally pooled deepest map (global context).
        self.conv_avg = ConvModule(
            last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)

        self.ffm = FeatureFusionModule(**ffm_cfg)

        self.upsample_mode = upsample_mode
        self.align_corners = align_corners

    def forward(self, x):
        outs = list(self.backbone(x))
        # Every ARM step upsamples to the next shallower backbone map, so the
        # backbone must return more maps than there are ARMs; otherwise the
        # negative index below would silently wrap to the deepest map.
        assert len(self.arms) < len(outs), (
            f'{len(self.arms)} attention refinement modules need at least '
            f'{len(self.arms) + 1} backbone feature maps, got {len(outs)}.')

        avg_feat = self.conv_avg(F.adaptive_avg_pool2d(outs[-1], 1))
        feature_up = resize(
            avg_feat,
            size=outs[-1].shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)

        arms_out = []
        for i, (arm, conv) in enumerate(zip(self.arms, self.convs)):
            # Refine the i-th deepest map, add the context carried up from
            # the previous (deeper) step, then move to the next shallower map.
            x_arm = arm(outs[-1 - i]) + feature_up
            feature_up = resize(
                x_arm,
                size=outs[-2 - i].shape[2:],
                mode=self.upsample_mode,
                align_corners=self.align_corners)
            feature_up = conv(feature_up)
            arms_out.append(feature_up)

        # Fuse the biggest backbone map with the last (shallowest) ARM
        # output; for the default two ARMs this is ``arms_out[1]``.
        feat_fuse = self.ffm(outs[0], arms_out[-1])

        # `outs[0]` is for the `STDCHead` auxiliary head, the ARM outputs for
        # auxiliary heads, and `feat_fuse` for the decoder head.
        return (outs[0], *arms_out, feat_fuse)
| |
|