# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
"""
Backbone modules.
"""
from collections import OrderedDict
from functools import partial
from typing import Dict, List
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torch.cuda.amp import autocast
from torchvision.models._utils import IntermediateLayerGetter
from util.misc import is_main_process, NestedTensor
from .position_encoding import build_position_encoding
from .swin import get_swinl
from .pev1 import get_pev1_and_fpn_backbone
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d whose batch statistics and affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with eps added before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.

    Args:
        n: number of channels; every buffer is a 1-D tensor of this length.
        eps: constant added to the running variance before the reciprocal
            square root.
    """

    def __init__(self, n, eps=1e-5):
        super().__init__()
        # Registered as buffers (not parameters), so they are part of the
        # state dict but are not returned by ``parameters()``.
        for buffer_name, make_initial in (
            ("weight", torch.ones),
            ("bias", torch.zeros),
            ("running_mean", torch.zeros),
            ("running_var", torch.ones),
        ):
            self.register_buffer(buffer_name, make_initial(n))
        self.eps = eps

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # This module has no ``num_batches_tracked`` buffer; drop that entry
        # (if present) so the parent loader does not see it as an
        # unexpected key.
        state_dict.pop(prefix + "num_batches_tracked", None)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x):
        """Apply the fixed per-channel affine transform ``x * scale + shift``."""
        # Reshapes are done up front to keep the arithmetic fuser-friendly.
        gamma = self.weight.reshape(1, -1, 1, 1)
        beta = self.bias.reshape(1, -1, 1, 1)
        variance = self.running_var.reshape(1, -1, 1, 1)
        mean = self.running_mean.reshape(1, -1, 1, 1)

        scale = gamma * (variance + self.eps).rsqrt()
        shift = beta - mean * scale
        return x * scale + shift
class BackboneBase(nn.Module):
    """Wrap a backbone and return selected stages as masked feature maps.

    Args:
        backbone: module whose ``layer2``/``layer3``/``layer4`` children are
            tapped through ``IntermediateLayerGetter``.
        train_backbone: when False every backbone parameter is frozen. When
            True only parameters whose name contains ``layer2``, ``layer3``
            or ``layer4`` stay trainable; all others are frozen.
        return_interm_layers: when True return ``layer2``-``layer4`` under
            the keys ``"0"``, ``"1"``, ``"2"``; otherwise only ``layer4``
            under ``"0"``.

    Attributes:
        strides: hard-coded stride of each returned level.
        num_channels: hard-coded channel count of each returned level.
    """

    def __init__(
        self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool
    ):
        super().__init__()
        trainable_stages = ("layer2", "layer3", "layer4")
        for name, parameter in backbone.named_parameters():
            in_trainable_stage = any(stage in name for stage in trainable_stages)
            # Freeze everything when the backbone is not trained, and
            # otherwise freeze whatever lies outside layer2-layer4.
            if not train_backbone or not in_trainable_stage:
                parameter.requires_grad_(False)

        if return_interm_layers:
            # layer1 is not returned.
            return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
            self.strides = [8, 16, 32]
            self.num_channels = [512, 1024, 2048]
        else:
            return_layers = {"layer4": "0"}
            self.strides = [32]
            self.num_channels = [2048]
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone and pair every feature map with a resized mask.

        Returns:
            Dict mapping each returned-layer key to a ``NestedTensor`` holding
            the feature map and the mask interpolated to its spatial size.
        """
        xs = self.body(tensor_list.tensors)

        # The mask does not change between feature levels, so validate and
        # convert it once instead of once per level.
        mask = tensor_list.mask
        assert mask is not None
        float_mask = mask[None].float()

        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            level_mask = F.interpolate(float_mask, size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, level_mask)
        return out
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm.

    Args:
        name: name of a model constructor in ``torchvision.models``.
            ``resnet18`` and ``resnet34`` are rejected because the channel
            counts used by ``BackboneBase`` are hard coded.
        train_backbone: forwarded to ``BackboneBase``.
        return_interm_layers: forwarded to ``BackboneBase``.
        dilation: when True the last stage replaces its stride with
            dilation, and the last entry of ``self.strides`` is halved.
    """

    def __init__(
        self,
        name: str,
        train_backbone: bool,
        return_interm_layers: bool,
        dilation: bool,
    ):
        # Validate before building the model, so an unsupported name fails
        # immediately rather than after the network has been constructed
        # (and pretrained weights requested on the main process).
        assert name not in ("resnet18", "resnet34"), "number of channels are hard coded"

        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            # Only the main process asks torchvision for pretrained weights.
            pretrained=is_main_process(),
            norm_layer=FrozenBatchNorm2d,
        )
        super().__init__(backbone, train_backbone, return_interm_layers)
        if dilation:
            self.strides[-1] = self.strides[-1] // 2
class SwinBackbone(nn.Module):
    """Backbone built by ``get_swinl()`` that returns three masked feature levels.

    Attributes:
        features: names of the backbone outputs that are returned.
        strides: hard-coded stride of each returned level.
        num_channels: hard-coded channel count of each returned level.
    """

    def __init__(self):
        # Unlike the ResNet ``Backbone`` path, nothing here applies
        # FrozenBatchNorm2d, dilation, or layer2-4-only training.
        super().__init__()
        self.body = get_swinl()
        self.features = ["res3", "res4", "res5"]
        self.strides = [8, 16, 32]
        self.num_channels = [384, 768, 1536]

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone and pair every selected feature with a resized mask.

        Returns:
            Dict mapping each name in ``self.features`` to a ``NestedTensor``
            holding the feature map and the mask interpolated to its
            spatial size.
        """
        # Check the mask BEFORE indexing it: ``mask[None]`` on a missing mask
        # raises TypeError, and its result can never be None, so a check
        # placed after the subscript would be dead code.
        mask = tensor_list.mask
        assert mask is not None, "SwinBackbone requires a padding mask"
        float_mask = mask[None].float()

        xs = self.body(tensor_list.tensors)

        out: Dict[str, NestedTensor] = {}
        for name in self.features:
            feature = xs[name]
            level_mask = F.interpolate(float_mask, size=feature.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(feature, level_mask)
        return out
class PEv1Backbone(nn.Module):
    """Backbone built by ``get_pev1_and_fpn_backbone(args)`` with optional reduced precision.

    Args:
        args: configuration object; this class reads ``args.bf16`` and
            ``args.fp16`` and passes ``args`` on to the backbone factory.

    Attributes:
        features: output feature names, taken from ``body._out_features``.
        strides: values of ``body._out_feature_strides``, in that mapping's
            own iteration order.
        num_channels: values of ``body._out_feature_channels``, in that
            mapping's own iteration order.
    """

    def __init__(self, args):
        super().__init__()
        self.body = get_pev1_and_fpn_backbone(args)
        self.features = self.body._out_features
        self.bf16 = args.bf16
        self.fp16 = args.fp16
        _out_feature_strides = self.body._out_feature_strides
        _out_feature_channels = self.body._out_feature_channels
        # NOTE(review): these follow each mapping's key order, while forward()
        # emits features in ``self.features`` order -- confirm they agree.
        self.strides = [_out_feature_strides[f] for f in _out_feature_strides]
        self.num_channels = [_out_feature_channels[f] for f in _out_feature_channels]

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone and pair every selected feature with a resized mask.

        With ``bf16`` (checked first) or ``fp16`` set, the backbone runs under
        autocast on an input cast to that dtype and its outputs are cast back
        to float32. Otherwise it runs on the input unchanged.

        Returns:
            Dict mapping each name in ``self.features`` to a ``NestedTensor``
            holding the feature map and the mask interpolated to its
            spatial size.
        """
        # Check the mask BEFORE indexing it: ``mask[None]`` on a missing mask
        # raises TypeError, and its result can never be None, so a check
        # placed after the subscript would be dead code.
        mask = tensor_list.mask
        assert mask is not None, "PEv1Backbone requires a padding mask"
        float_mask = mask[None].float()

        if self.bf16:
            reduced_dtype = torch.bfloat16
        elif self.fp16:
            reduced_dtype = torch.float16
        else:
            reduced_dtype = None

        images = tensor_list.tensors
        if reduced_dtype is None:
            xs = self.body(images)
        else:
            with autocast(dtype=reduced_dtype):
                xs = self.body(images.to(reduced_dtype))
            # Hand float32 features to the rest of the model.
            xs = {k: v.float() for k, v in xs.items()}

        out: Dict[str, NestedTensor] = {}
        for name in self.features:
            feature = xs[name]
            level_mask = F.interpolate(float_mask, size=feature.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(feature, level_mask)
        return out
class Joiner(nn.Sequential):
    """Two-stage container: a backbone (index 0) and a position embedding (index 1).

    ``strides`` and ``num_channels`` are copied from the backbone so callers
    can read them from the joined model.
    """

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)
        self.num_channels = backbone.num_channels
        self.strides = backbone.strides

    def forward(self, tensor_list: NestedTensor):
        """Return ``(features, position_encodings)`` as two parallel lists.

        Features are ordered by sorting the backbone's output dict by key.
        Each position encoding is cast to the dtype of its feature's tensors.
        """
        backbone, position_embedding = self[0], self[1]
        features = backbone(tensor_list)

        out: List[NestedTensor] = [feature for _, feature in sorted(features.items())]
        # Position encodings are computed only after every feature is collected.
        pos = [
            position_embedding(feature).to(feature.tensors.dtype) for feature in out
        ]
        return out, pos
def build_backbone(args):
    """Build the backbone selected by ``args.backbone``, joined with its position encoding.

    A name containing ``"swin"`` selects ``SwinBackbone``, otherwise one
    containing ``"pev1"`` selects ``PEv1Backbone``, and anything else is
    treated as a ``Backbone`` (ResNet) model name.

    Returns:
        A ``Joiner`` wrapping the chosen backbone and the position encoding.
    """
    pos_encoding = build_position_encoding(args)

    # Both flags are read from args on every call, but only the ResNet
    # branch below consumes them.
    trainable = args.lr_backbone > 0
    multi_level = args.masks or (args.num_feature_levels > 1)

    arch = args.backbone
    if "swin" in arch:
        net = SwinBackbone()
    elif "pev1" in arch:
        net = PEv1Backbone(args)
    else:
        net = Backbone(arch, trainable, multi_level, args.dilation)

    return Joiner(net, pos_encoding)
|