| | |
| | import fvcore.nn.weight_init as weight_init |
| | import torch.nn.functional as F |
| |
|
| | from detectron2.layers import CNNBlockBase, Conv2d, get_norm |
| | from detectron2.modeling import BACKBONE_REGISTRY |
| | from detectron2.modeling.backbone.resnet import ( |
| | BasicStem, |
| | BottleneckBlock, |
| | DeformBottleneckBlock, |
| | ResNet, |
| | ) |
| |
|
| |
|
class DeepLabStem(CNNBlockBase):
    """
    Stem used by the DeepLab flavour of ResNet, i.e. everything that runs
    before the first residual stage.

    It stacks three 3x3 convolutions (only the first one is strided), each
    followed by an in-place ReLU, and finishes with a stride-2 3x3 max pool.
    """

    def __init__(self, in_channels=3, out_channels=128, norm="BN"):
        """
        Args:
            in_channels (int): channel count of the input fed to the first conv.
            out_channels (int): channel count produced by the third conv. The
                first two convs produce ``out_channels // 2`` channels.
            norm (str or callable): normalization attached to every conv layer.
                See :func:`layers.get_norm` for supported format.
        """
        # The trailing 4 is forwarded to the base-class constructor; it matches
        # the overall downsampling of this stem (stride-2 conv + stride-2 pool).
        super().__init__(in_channels, out_channels, 4)
        self.in_channels = in_channels
        half_channels = out_channels // 2

        def conv3x3(num_in, num_out, stride):
            # Bias-free 3x3 conv with "same" padding and its own norm layer.
            return Conv2d(
                num_in,
                num_out,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
                norm=get_norm(norm, num_out),
            )

        self.conv1 = conv3x3(in_channels, half_channels, 2)
        self.conv2 = conv3x3(half_channels, half_channels, 1)
        self.conv3 = conv3x3(half_channels, out_channels, 1)

        # Initialize only after all three layers exist, in layer order.
        for layer in (self.conv1, self.conv2, self.conv3):
            weight_init.c2_msra_fill(layer)

    def forward(self, x):
        """Apply conv1 -> conv2 -> conv3 (ReLU after each), then max-pool."""
        x = F.relu_(self.conv1(x))
        x = F.relu_(self.conv2(x))
        x = F.relu_(self.conv3(x))
        return F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
| |
|
| |
|
@BACKBONE_REGISTRY.register()
def build_resnet_deeplab_backbone(cfg, input_shape):
    """
    Build a DeepLab-style ResNet backbone from a config.

    Args:
        cfg: config object; the ``MODEL.RESNETS.*`` entries and
            ``MODEL.BACKBONE.FREEZE_AT`` are read.
        input_shape: object whose ``channels`` attribute gives the number of
            input channels of the stem.

    Returns:
        ResNet: a :class:`ResNet` instance, after ``freeze(FREEZE_AT)``.

    Raises:
        ValueError: if ``MODEL.RESNETS.STEM_TYPE`` is neither "basic" nor "deeplab".
        AssertionError: if the res4/res5 dilation settings are not an allowed
            combination.
        KeyError: if ``DEPTH`` is not one of 50/101/152, or an entry of
            ``OUT_FEATURES`` is not one of "res2".."res5".
    """
    resnets = cfg.MODEL.RESNETS
    norm = resnets.NORM

    # Pick the stem class first, then construct it once.
    stem_type = resnets.STEM_TYPE
    if stem_type == "basic":
        stem_class = BasicStem
    elif stem_type == "deeplab":
        stem_class = DeepLabStem
    else:
        raise ValueError("Unknown stem type: {}".format(stem_type))
    stem = stem_class(
        in_channels=input_shape.channels,
        out_channels=resnets.STEM_OUT_CHANNELS,
        norm=norm,
    )

    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features = resnets.OUT_FEATURES
    depth = resnets.DEPTH
    num_groups = resnets.NUM_GROUPS
    width_per_group = resnets.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels = resnets.STEM_OUT_CHANNELS
    out_channels = resnets.RES2_OUT_CHANNELS
    stride_in_1x1 = resnets.STRIDE_IN_1X1
    res4_dilation = resnets.RES4_DILATION
    res5_dilation = resnets.RES5_DILATION
    deform_on_per_stage = resnets.DEFORM_ON_PER_STAGE
    deform_modulated = resnets.DEFORM_MODULATED
    deform_num_groups = resnets.DEFORM_NUM_GROUPS
    res5_multi_grid = resnets.RES5_MULTI_GRID

    assert res4_dilation in {1, 2}, "res4_dilation cannot be {}.".format(res4_dilation)
    assert res5_dilation in {1, 2, 4}, "res5_dilation cannot be {}.".format(res5_dilation)
    if res4_dilation == 2:
        # A dilated res4 is only accepted together with res5 dilation 4.
        assert res5_dilation == 4

    blocks_by_depth = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}
    num_blocks_per_stage = blocks_by_depth[depth]

    # Only build stages up to the deepest requested output feature.
    stage_number_of = {"res2": 2, "res3": 3, "res4": 4, "res5": 5}
    last_stage = max([stage_number_of[name] for name in out_features])
    dilation_of = {4: res4_dilation, 5: res5_dilation}

    stages = []
    for pos, stage_number in enumerate(range(2, last_stage + 1)):
        num_blocks = num_blocks_per_stage[pos]
        dilation = dilation_of.get(stage_number, 1)

        # The first stage and any dilated stage keep resolution; the others
        # downsample by 2 in their first block.
        if pos == 0 or dilation > 1:
            first_stride = 1
        else:
            first_stride = 2

        stage_kwargs = {
            "num_blocks": num_blocks,
            "stride_per_block": [first_stride] + [1] * (num_blocks - 1),
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm": norm,
            "bottleneck_channels": bottleneck_channels,
            "stride_in_1x1": stride_in_1x1,
            "num_groups": num_groups,
        }

        if deform_on_per_stage[pos]:
            stage_kwargs["block_class"] = DeformBottleneckBlock
            stage_kwargs["deform_modulated"] = deform_modulated
            stage_kwargs["deform_num_groups"] = deform_num_groups
        else:
            stage_kwargs["block_class"] = BottleneckBlock

        if stage_number == 5:
            # res5 takes a per-block dilation (multi-grid) instead of one value.
            stage_kwargs["dilation_per_block"] = [dilation * mg for mg in res5_multi_grid]
        else:
            stage_kwargs["dilation"] = dilation

        stages.append(ResNet.make_stage(**stage_kwargs))

        # Channel widths double from one stage to the next.
        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2

    return ResNet(stem, stages, out_features=out_features).freeze(freeze_at)
| |
|