import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmdet.registry import MODELS


class Transition(BaseModule):
    """Base class for transition.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
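
    Example:
        >>> # Sketch of a custom transition (illustrative only):
        >>> # subclasses override ``forward``.
        >>> import torch
        >>> class IdentityTrans(Transition):
        ...     def forward(self, x):
        ...         return x
        >>> trans = IdentityTrans(256, 256)
        >>> x = torch.rand(1, 256, 8, 8)
        >>> torch.equal(trans(x), x)
        True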
| | """ |
| |
|
| | def __init__(self, in_channels, out_channels, init_cfg=None): |
| | super().__init__(init_cfg) |
| | self.in_channels = in_channels |
| | self.out_channels = out_channels |
| |
|
| | def forward(x): |
| | pass |
| |
|
| |
|
class UpInterpolationConv(Transition):
    """A transition used for up-sampling.

    Up-samples the input by interpolation and then refines the feature with
    a convolution layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Up-sampling factor. Default: 2.
        mode (str): Interpolation mode. Default: 'nearest'.
        align_corners (bool): Whether to align corners when interpolating.
            Default: None.
        kernel_size (int): Kernel size for the conv. Default: 3.
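
    Example:
        >>> # Minimal sketch (shapes are illustrative): 2x nearest-neighbor
        >>> # up-sampling followed by a 3x3 conv that keeps the resolution.
        >>> import torch
        >>> trans = UpInterpolationConv(256, 256)
        >>> x = torch.rand(1, 256, 32, 32)
        >>> trans(x).shape
        torch.Size([1, 256, 64, 64])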
| | """ |
| |
|
    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=2,
                 mode='nearest',
                 align_corners=None,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.mode = mode
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        self.conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            **kwargs)

    def forward(self, x):
        x = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners)
        x = self.conv(x)
        return x


class LastConv(Transition):
    """A transition used for refining the output of the last stage.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_inputs (int): Number of input feature maps from the stacked
            stages; ``forward`` expects a list of this length.
        kernel_size (int): Kernel size for the conv. Default: 3.
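
    Example:
        >>> # Minimal sketch: refine only the last of three stacked outputs.
        >>> import torch
        >>> trans = LastConv(256, 256, num_inputs=3)
        >>> feats = [torch.rand(1, 256, 16, 16) for _ in range(3)]
        >>> trans(feats).shape
        torch.Size([1, 256, 16, 16])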
| | """ |
| |
|
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_inputs,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.num_inputs = num_inputs
        self.conv_out = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            **kwargs)

    def forward(self, inputs):
        assert len(inputs) == self.num_inputs
        return self.conv_out(inputs[-1])


@MODELS.register_module()
class FPG(BaseModule):
| | """FPG. |
| | |
| | Implementation of `Feature Pyramid Grids (FPG) |
| | <https://arxiv.org/abs/2004.03580>`_. |
| | This implementation only gives the basic structure stated in the paper. |
| | But users can implement different type of transitions to fully explore the |
| | the potential power of the structure of FPG. |
| | |
| | Args: |
| | in_channels (int): Number of input channels (feature maps of all levels |
| | should have the same channels). |
| | out_channels (int): Number of output channels (used at each scale) |
| | num_outs (int): Number of output scales. |
| | stack_times (int): The number of times the pyramid architecture will |
| | be stacked. |
| | paths (list[str]): Specify the path order of each stack level. |
| | Each element in the list should be either 'bu' (bottom-up) or |
| | 'td' (top-down). |
| | inter_channels (int): Number of inter channels. |
| | same_up_trans (dict): Transition that goes down at the same stage. |
| | same_down_trans (dict): Transition that goes up at the same stage. |
| | across_lateral_trans (dict): Across-pathway same-stage |
| | across_down_trans (dict): Across-pathway bottom-up connection. |
| | across_up_trans (dict): Across-pathway top-down connection. |
| | across_skip_trans (dict): Across-pathway skip connection. |
| | output_trans (dict): Transition that trans the output of the |
| | last stage. |
| | start_level (int): Index of the start input backbone level used to |
| | build the feature pyramid. Default: 0. |
| | end_level (int): Index of the end input backbone level (exclusive) to |
| | build the feature pyramid. Default: -1, which means the last level. |
| | add_extra_convs (bool): It decides whether to add conv |
| | layers on top of the original feature maps. Default to False. |
| | If True, its actual mode is specified by `extra_convs_on_inputs`. |
| | norm_cfg (dict): Config dict for normalization layer. Default: None. |
| | init_cfg (dict or list[dict], optional): Initialization config dict. |
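
    Example:
        >>> # A minimal, illustrative configuration (not from the paper).
        >>> # `across_down_trans` and `across_skip_trans` are overridden
        >>> # here so that every transition type is registered and all fused
        >>> # features have matching shapes, mirroring the released FPG
        >>> # configs.
        >>> import torch
        >>> in_channels = [256, 512, 1024, 2048]
        >>> inputs = [
        ...     torch.rand(1, c, 64 // 2**i, 64 // 2**i)
        ...     for i, c in enumerate(in_channels)
        ... ]
        >>> self = FPG(
        ...     in_channels=in_channels,
        ...     out_channels=256,
        ...     num_outs=5,
        ...     stack_times=2,
        ...     paths=['bu', 'bu'],
        ...     across_down_trans=dict(
        ...         type='interpolation_conv', mode='nearest', kernel_size=3),
        ...     across_skip_trans=dict(type='conv', kernel_size=1),
        ...     skip_inds=[(), (), (), (), ()])
        >>> outputs = self.forward(inputs)
        >>> for out in outputs:
        ...     print(tuple(out.shape))
        (1, 256, 64, 64)
        (1, 256, 32, 32)
        (1, 256, 16, 16)
        (1, 256, 8, 8)
        (1, 256, 4, 4)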
| | """ |
| |
|
    transition_types = {
        'conv': ConvModule,
        'interpolation_conv': UpInterpolationConv,
        'last_conv': LastConv,
    }

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 stack_times,
                 paths,
                 inter_channels=None,
                 same_down_trans=None,
                 same_up_trans=dict(
                     type='conv', kernel_size=3, stride=2, padding=1),
                 across_lateral_trans=dict(type='conv', kernel_size=1),
                 across_down_trans=dict(type='conv', kernel_size=3),
                 across_up_trans=None,
                 across_skip_trans=dict(type='identity'),
                 output_trans=dict(type='last_conv', kernel_size=3),
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 norm_cfg=None,
                 skip_inds=None,
                 init_cfg=[
                     dict(type='Caffe2Xavier', layer='Conv2d'),
                     dict(
                         type='Constant',
                         layer=[
                             '_BatchNorm', '_InstanceNorm', 'GroupNorm',
                             'LayerNorm'
                         ],
                         val=1.0)
                 ]):
        super(FPG, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        if inter_channels is None:
            self.inter_channels = [out_channels for _ in range(num_outs)]
        elif isinstance(inter_channels, int):
            self.inter_channels = [inter_channels for _ in range(num_outs)]
        else:
            assert isinstance(inter_channels, list)
            assert len(inter_channels) == num_outs
            self.inter_channels = inter_channels
        self.stack_times = stack_times
        self.paths = paths
        assert isinstance(paths, list) and len(paths) == stack_times
        for d in paths:
            assert d in ('bu', 'td')

        self.same_down_trans = same_down_trans
        self.same_up_trans = same_up_trans
        self.across_lateral_trans = across_lateral_trans
        self.across_down_trans = across_down_trans
        self.across_up_trans = across_up_trans
        self.output_trans = output_trans
        self.across_skip_trans = across_skip_trans

        self.with_bias = norm_cfg is None
        # `skip_inds` must be specified if across-skip transitions are used
        if self.across_skip_trans is not None:
            assert skip_inds is not None
        self.skip_inds = skip_inds
        assert len(self.skip_inds[0]) <= self.stack_times

        if end_level == -1 or end_level == self.num_ins - 1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level is not the last level, no extra level is allowed
            self.backbone_end_level = end_level + 1
            assert end_level < self.num_ins
            assert num_outs == end_level - start_level + 1
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        # build 1x1 lateral convs to adjust the channels of each input level
        self.lateral_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = nn.Conv2d(self.in_channels[i],
                               self.inter_channels[i - self.start_level], 1)
            self.lateral_convs.append(l_conv)

        # add extra pyramid levels by strided conv or max-pooling
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        self.extra_downsamples = nn.ModuleList()
        for i in range(extra_levels):
            if self.add_extra_convs:
                fpn_idx = self.backbone_end_level - self.start_level + i
                extra_conv = nn.Conv2d(
                    self.inter_channels[fpn_idx - 1],
                    self.inter_channels[fpn_idx],
                    3,
                    stride=2,
                    padding=1)
                self.extra_downsamples.append(extra_conv)
            else:
                self.extra_downsamples.append(nn.MaxPool2d(1, stride=2))

        # build transitions between stacked pyramid stages
        self.fpn_transitions = nn.ModuleList()  # stack
        for s in range(self.stack_times):
            stage_trans = nn.ModuleList()  # stage
            for i in range(self.num_outs):
                # transitions of one node: same-pathway, across-pathway
                # lateral/down/up and skip connections
                trans = nn.ModuleDict()
                if s in self.skip_inds[i]:
                    stage_trans.append(trans)
                    continue
                # build same-pathway up trans (used in bottom-up paths)
                if i == 0 or self.same_up_trans is None:
                    same_up_trans = None
                else:
                    same_up_trans = self.build_trans(
                        self.same_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['same_up'] = same_up_trans
                # build same-pathway down trans (used in top-down paths)
                if i == self.num_outs - 1 or self.same_down_trans is None:
                    same_down_trans = None
                else:
                    same_down_trans = self.build_trans(
                        self.same_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['same_down'] = same_down_trans
                # build across-pathway lateral trans
                across_lateral_trans = self.build_trans(
                    self.across_lateral_trans, self.inter_channels[i],
                    self.inter_channels[i])
                trans['across_lateral'] = across_lateral_trans
                # build across-pathway down trans (from the coarser level)
                if i == self.num_outs - 1 or self.across_down_trans is None:
                    across_down_trans = None
                else:
                    across_down_trans = self.build_trans(
                        self.across_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['across_down'] = across_down_trans
                # build across-pathway up trans (from the finer level)
                if i == 0 or self.across_up_trans is None:
                    across_up_trans = None
                else:
                    across_up_trans = self.build_trans(
                        self.across_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['across_up'] = across_up_trans
                # build across-pathway skip trans; its input is the stage-0
                # feature at the same level, hence inter_channels[i]
                if self.across_skip_trans is None:
                    across_skip_trans = None
                else:
                    across_skip_trans = self.build_trans(
                        self.across_skip_trans, self.inter_channels[i],
                        self.inter_channels[i])
                trans['across_skip'] = across_skip_trans
                stage_trans.append(trans)
            self.fpn_transitions.append(stage_trans)

        # build the output transitions that produce the final features
        self.output_transition = nn.ModuleList()
        for i in range(self.num_outs):
            trans = self.build_trans(
                self.output_trans,
                self.inter_channels[i],
                self.out_channels,
                num_inputs=self.stack_times + 1)
            self.output_transition.append(trans)

        self.relu = nn.ReLU(inplace=True)

    def build_trans(self, cfg, in_channels, out_channels, **extra_args):
        # instantiate a transition class registered in `transition_types`
        cfg_ = cfg.copy()
        trans_type = cfg_.pop('type')
        trans_cls = self.transition_types[trans_type]
        return trans_cls(in_channels, out_channels, **cfg_, **extra_args)

    def fuse(self, fuse_dict):
        # element-wise sum of all non-None features
        out = None
        for item in fuse_dict.values():
            if item is not None:
                if out is None:
                    out = item
                else:
                    out = out + item
        return out

    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # build all levels from original feature maps
        feats = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        for downsample in self.extra_downsamples:
            feats.append(downsample(feats[-1]))

        outs = [feats]

        for i in range(self.stack_times):
            current_outs = outs[-1]
            next_outs = []
            direction = self.paths[i]
            for j in range(self.num_outs):
                if i in self.skip_inds[j]:
                    next_outs.append(outs[-1][j])
                    continue
                # feature level
                if direction == 'td':
                    lvl = self.num_outs - j - 1
                else:
                    lvl = j
                # get transitions
                if direction == 'td':
                    same_trans = self.fpn_transitions[i][lvl]['same_down']
                else:
                    same_trans = self.fpn_transitions[i][lvl]['same_up']
                across_lateral_trans = self.fpn_transitions[i][lvl][
                    'across_lateral']
                across_down_trans = self.fpn_transitions[i][lvl][
                    'across_down']
                across_up_trans = self.fpn_transitions[i][lvl]['across_up']
                across_skip_trans = self.fpn_transitions[i][lvl][
                    'across_skip']
                # fuse results
                to_fuse = dict(
                    same=None, lateral=None, across_up=None, across_down=None)
                # same-pathway transition from the previously computed level
                if same_trans is not None:
                    to_fuse['same'] = same_trans(next_outs[-1])
                # across-pathway lateral connection
                if across_lateral_trans is not None:
                    to_fuse['lateral'] = across_lateral_trans(
                        current_outs[lvl])
                # across-pathway up connection (from the finer level)
                if lvl > 0 and across_up_trans is not None:
                    to_fuse['across_up'] = across_up_trans(
                        current_outs[lvl - 1])
                # across-pathway down connection (from the coarser level)
                if lvl < self.num_outs - 1 and across_down_trans is not None:
                    to_fuse['across_down'] = across_down_trans(
                        current_outs[lvl + 1])
                # across-pathway skip connection from the stage-0 features
                if across_skip_trans is not None:
                    to_fuse['across_skip'] = across_skip_trans(outs[0][lvl])
                x = self.fuse(to_fuse)
                next_outs.append(x)

            if direction == 'td':
                outs.append(next_outs[::-1])
            else:
                outs.append(next_outs)

        # output transitions over all stacked stages
        final_outs = []
        for i in range(self.num_outs):
            lvl_out_list = []
            for s in range(len(outs)):
                lvl_out_list.append(outs[s][i])
            lvl_out = self.output_transition[i](lvl_out_list)
            final_outs.append(lvl_out)

        return final_outs