File size: 3,576 Bytes
226675b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from torch.optim import lr_scheduler
from rscd.models.backbones import resnet_bit
class BIT_Backbone(torch.nn.Module):
    """Shared-weight ResNet feature extractor for a pair of images.

    Both inputs go through the same truncated ResNet trunk, are optionally
    upsampled by 2x, and are projected to ``output_nc`` channels by a 3x3
    convolution.
    """

    def __init__(self, input_nc, output_nc,
                 resnet_stages_num=5, backbone='resnet18',
                 if_upsample_2x=True):
        """Build the ResNet trunk and the prediction convolution.

        Args:
            input_nc: Accepted for config compatibility only; it is never
                read, so the ResNet stem keeps whatever input channel count
                ``resnet_bit`` builds it with.
            output_nc: Number of channels produced by the final 3x3 conv.
            resnet_stages_num: How many ResNet stages to run (3, 4 or 5).
            backbone: One of ``'resnet18'``, ``'resnet34'``, ``'resnet50'``.
            if_upsample_2x: If true, upsample the trunk output by a factor
                of 2 before the prediction conv.

        Raises:
            NotImplementedError: If ``backbone`` or ``resnet_stages_num`` is
                not one of the supported values.
        """
        super(BIT_Backbone, self).__init__()

        # Validate the whole configuration *before* building the backbone:
        # constructing a pretrained ResNet is expensive, so a bad option
        # should fail immediately instead of after that work is done.
        if backbone in ('resnet18', 'resnet34'):
            expand = 1
        elif backbone == 'resnet50':
            # resnet50 stages output 4x the channels of resnet18/34 stages.
            expand = 4
        else:
            raise NotImplementedError(
                "unsupported backbone {!r}; expected 'resnet18', "
                "'resnet34' or 'resnet50'".format(backbone))

        # Channels produced by the last stage that forward_single() runs.
        if resnet_stages_num == 5:
            layers = 512 * expand
        elif resnet_stages_num == 4:
            layers = 256 * expand
        elif resnet_stages_num == 3:
            layers = 128 * expand
        else:
            raise NotImplementedError(
                "unsupported resnet_stages_num {!r}; expected 3, 4 or 5"
                .format(resnet_stages_num))

        # The factory functions on resnet_bit are named after the backbone.
        build_resnet = getattr(resnet_bit, backbone)
        self.resnet = build_resnet(
            pretrained=True,
            replace_stride_with_dilation=[False, True, True])

        self.upsamplex2 = nn.Upsample(scale_factor=2)
        self.resnet_stages_num = resnet_stages_num
        self.if_upsample_2x = if_upsample_2x
        self.conv_pred = nn.Conv2d(layers, output_nc, kernel_size=3, padding=1)

    def forward(self, x1, x2):
        """Extract features from both inputs with the same weights.

        Returns:
            A two-element list ``[features_of_x1, features_of_x2]``.
        """
        # forward backbone resnet
        x1 = self.forward_single(x1)
        x2 = self.forward_single(x2)
        return [x1, x2]

    def forward_single(self, x):
        """Run one input through the trunk, optional upsample and 3x3 conv.

        Raises:
            NotImplementedError: If ``resnet_stages_num`` was changed to a
                value above 5 after construction.
        """
        if self.resnet_stages_num > 5:
            raise NotImplementedError

        # resnet stem
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        # Resolution / channel notes below are for resnet18/34 and assume
        # resnet_bit follows torchvision's ResNet layout -- TODO confirm.
        x_4 = self.resnet.layer1(x)    # 1/4, in=64, out=64
        x_8 = self.resnet.layer2(x_4)  # 1/8, in=64, out=128

        if self.resnet_stages_num > 3:
            x_8 = self.resnet.layer3(x_8)  # 1/8, in=128, out=256

        if self.resnet_stages_num == 5:
            # Presumably still 1/8 (not 1/32): layer3/layer4 are requested
            # with dilation instead of stride in __init__.
            x_8 = self.resnet.layer4(x_8)  # in=256, out=512

        if self.if_upsample_2x:
            x = self.upsamplex2(x_8)
        else:
            x = x_8

        # output layers
        x = self.conv_pred(x)
        return x
def BIT_backbone_func(cfg):
    """Create a ``BIT_Backbone`` from the attributes of a config object.

    ``cfg`` must expose ``input_nc``, ``output_nc``, ``resnet_stages_num``,
    ``backbone`` and ``if_upsample_2x`` as attributes; each one is passed
    to the constructor under the same keyword name.
    """
    option_names = (
        'input_nc',
        'output_nc',
        'resnet_stages_num',
        'backbone',
        'if_upsample_2x',
    )
    ctor_kwargs = {option: getattr(cfg, option) for option in option_names}
    return BIT_Backbone(**ctor_kwargs)
if __name__ == '__main__':
    # Smoke test: build the backbone from a config and print the model,
    # the two output tensors and their shapes for a random image pair.
    # munch is only needed for this demo, so it is imported locally.
    from munch import DefaultMunch

    x1 = torch.rand(4, 3, 512, 512)
    x2 = torch.rand(4, 3, 512, 512)

    cfg = dict(
        type='BIT_Backbone',
        input_nc=3,
        output_nc=32,
        resnet_stages_num=4,
        backbone='resnet18',
        if_upsample_2x=True,
    )
    # The factory reads options as attributes, so wrap the plain dict.
    cfg = DefaultMunch.fromDict(cfg)

    model = BIT_backbone_func(cfg)
    model.eval()
    print(model)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outs = model(x1, x2)

    print('BIT', outs)
    for out in outs:
        print(out.shape)