| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | |
| | |
| | |
| | from itertools import chain |
| | |
| | from einops import rearrange |
| | from torch.hub import load_state_dict_from_url |
| |
|
def GlobalAvgPool2D():
    """Return a new global-average-pooling layer (``nn.AdaptiveAvgPool2d(1)``).

    Each call builds a fresh module that pools every channel of an
    ``(N, C, H, W)`` input down to ``(N, C, 1, 1)``.
    """
    return nn.AdaptiveAvgPool2d(1)


# Download URLs for pretrained ResNet weights.
# NOTE(review): not referenced by any code visible in this file (neither is the
# imported ``load_state_dict_from_url``); kept for external users.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}
| |
|
class Cross_transformer_backbone(nn.Module):
    """Channel-wise cross-attention from ``input_feature`` to ``features``.

    ``input_feature`` provides the queries, and ``features`` provides both keys
    and values. Each goes through its own 3x3 conv-BN-ReLU; all three keep
    ``in_channels`` channels. Both inputs are 4-D ``(N, C, H, W)`` maps that
    must agree on ``N``, ``H`` and ``W``.

    For every row ``h``, a score is computed between each query channel and
    each key channel by contracting over the width axis. Scores are scaled by
    ``1 / sqrt(W)`` and softmax-normalised over the key channels. The
    attended values are concatenated with the raw ``input_feature``, fused by
    a 3x3 conv-BN-ReLU (``mlp``) and added back as a residual. The output
    therefore has the same shape as ``input_feature``.

    NOTE(review): ``to_key``, ``to_value``, ``softmax`` and ``gamma_cam_lay3``
    are created but never used in ``forward``. They are kept so parameter
    names, initialisation order and checkpoint keys stay the same.
    """

    @staticmethod
    def _conv_bn_relu(channels_in, channels_out, *, bias, inplace):
        # 3x3 "same"-padded conv followed by BatchNorm and ReLU.
        return nn.Sequential(
            nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=bias),
            nn.BatchNorm2d(channels_out),
            nn.ReLU(inplace=inplace),
        )

    def __init__(self, in_channels = 48):
        super().__init__()
        # Not used by forward (see class docstring); creation order is preserved.
        self.to_key = nn.Linear(in_channels * 2, in_channels, bias=False)
        self.to_value = nn.Linear(in_channels * 2, in_channels, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.gamma_cam_lay3 = nn.Parameter(torch.zeros(1))

        # Query / key / value projections.
        self.cam_layer0 = self._conv_bn_relu(in_channels, in_channels, bias=True, inplace=False)
        self.cam_layer1 = self._conv_bn_relu(in_channels, in_channels, bias=False, inplace=True)
        self.cam_layer2 = self._conv_bn_relu(in_channels, in_channels, bias=False, inplace=True)
        # Fuses [input_feature, attended values] (2C channels) back down to C.
        self.mlp = self._conv_bn_relu(in_channels * 2, in_channels, bias=False, inplace=True)

    def forward(self, input_feature, features):
        query = self.cam_layer0(input_feature)
        key = self.cam_layer1(features)
        value = self.cam_layer2(features)

        # scores[n, c_q, c_k, h]: dot product over the width axis between query
        # channel c_q and key channel c_k on row h.
        scores = torch.einsum("nlhd,nshd->nlsh", query, key)
        temperature = 1. / query.size(3) ** .5
        weights = torch.softmax(temperature * scores, dim=2)  # normalise over key channels
        attended = torch.einsum("nlsh,nshd->nlhd", weights, value).contiguous()

        fused = self.mlp(torch.cat([input_feature, attended], dim=1))
        return input_feature + fused
| |
|
class Cross_transformer(nn.Module):
    """Fuse a query map with three companion maps via channel attention.

    ``forward(input_feature, features)`` takes ``input_feature`` of shape
    ``(B, C, H, W)`` and a sequence ``features``. The first three entries of
    ``features`` must match ``input_feature`` in batch size, channel count and
    number of spatial positions ``H * W``.

    Steps:

    1. Each map is flattened to ``(B, H*W, C)`` and mixed over channels by its
       own ``nn.Linear`` (``fa`` for the query, ``fb``/``fc``/``fd`` for the
       companions).
    2. The projected query attends to itself and to each projected companion
       with ``attention_layer``.
    3. The four ``(B, C, H, W)`` results are concatenated and fused by a 1x1
       conv-BN-ReLU (``fuse``).
    4. The fused map is scaled by the learnable ``gamma_cam_lay3``, added to
       ``input_feature`` and passed through ``to_out`` (3x3 conv-BN-ReLU).

    ``gamma_cam_lay3`` is initialised to 0, so at initialisation the output
    is just ``to_out(input_feature)``. Output shape: ``(B, C, H, W)``.
    """

    def __init__(self, in_channels = 48):
        super(Cross_transformer, self).__init__()
        self.fa = nn.Linear(in_channels, in_channels, bias=False)
        self.fb = nn.Linear(in_channels, in_channels, bias=False)
        self.fc = nn.Linear(in_channels, in_channels, bias=False)
        self.fd = nn.Linear(in_channels, in_channels, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.to_out = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        # Residual gate for the attention branch; starts at 0 (branch disabled).
        self.gamma_cam_lay3 = nn.Parameter(torch.zeros(1))
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channels * 4, in_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )

    def attention_layer(self, q, k, v, m_batchsize, C, height, width):
        """Channel attention of ``q`` over ``k``/``v``.

        ``q``, ``k`` and ``v`` are ``(B, C, N)`` with ``N == height * width``.
        The method computes the ``(B, C, C)`` affinity ``q @ k^T`` and turns it
        into weights with ``softmax(rowmax - affinity)`` over the last axis.
        It applies those weights to ``v`` and returns the result reshaped to
        ``(B, C, height, width)``.
        """
        energy = torch.bmm(q, k.permute(0, 2, 1))
        # rowmax - energy: the most similar channel gets the SMALLEST
        # (non-negative) logit, which pushes weight toward dissimilar channels.
        # This appears to follow DANet's channel-attention formulation.
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)
        out = torch.bmm(attention, v)  # bmm output is contiguous, so view is safe
        return out.view(m_batchsize, C, height, width)

    def forward(self, input_feature, features):
        m_batchsize, C, height, width = input_feature.size()

        def project(linear, feature):
            # (B, C, H, W) -> (B, HW, C) -> mix channels -> back to (B, C, HW).
            # Use reshape rather than view so non-contiguous inputs
            # (e.g. transposed or sliced maps) are accepted.
            tokens = feature.reshape(m_batchsize, C, -1).permute(0, 2, 1)
            return linear(tokens).permute(0, 2, 1)

        fa = project(self.fa, input_feature)
        fb = project(self.fb, features[0])
        fc = project(self.fc, features[1])
        fd = project(self.fd, features[2])

        # The query is always the projected input. Keys/values are the input
        # itself, then each companion map.
        attended = [
            self.attention_layer(fa, kv, kv, m_batchsize, C, height, width)
            for kv in (fa, fb, fc, fd)
        ]
        atten = self.fuse(torch.cat(attended, dim=1))

        out = self.gamma_cam_lay3 * atten + input_feature
        return self.to_out(out)
| |
|
| |
|
class SceneRelation(nn.Module):
    """Gate per-level content features with a projected scene descriptor.

    For each level ``i``:

    - ``features[i]`` (``channel_list[i]`` channels) is encoded to
      ``out_channels`` by ``content_encoders[i]`` (1x1 conv-BN-ReLU);
    - the result is multiplied by ``sigmoid(scene_proj(scene_feature))``,
      where ``scene_proj`` is a 1x1 conv-ReLU-conv that maps ``in_channels``
      to ``out_channels``.

    With ``scale_aware_proj=True`` each level has its own ``scene_proj``.
    With ``False``, a single projection is shared by all levels.

    NOTE(review): ``feature_reencoders`` is built but never used in
    ``forward``. It is kept so parameter names and checkpoint keys stay the
    same.
    """

    def __init__(self,
                 in_channels,
                 channel_list,
                 out_channels,
                 scale_aware_proj=True):
        super(SceneRelation, self).__init__()
        self.scale_aware_proj = scale_aware_proj

        if scale_aware_proj:
            # One scene projection per pyramid level.
            self.scene_encoder = nn.ModuleList(
                [nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 1),
                    nn.ReLU(True),
                    nn.Conv2d(out_channels, out_channels, 1),
                ) for _ in range(len(channel_list))]
            )
        else:
            # A single scene projection shared by all levels.
            self.scene_encoder = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.ReLU(True),
                nn.Conv2d(out_channels, out_channels, 1),
            )
        self.content_encoders = nn.ModuleList()
        self.feature_reencoders = nn.ModuleList()
        for c in channel_list:
            self.content_encoders.append(
                nn.Sequential(
                    nn.Conv2d(c, out_channels, 1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(True)
                )
            )
            self.feature_reencoders.append(
                nn.Sequential(
                    nn.Conv2d(c, out_channels, 1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(True)
                )
            )

        self.normalizer = nn.Sigmoid()

    def forward(self, scene_feature, features: list):
        """Return ``[sigmoid(scene_proj_i(scene_feature)) * content_i(features[i])]``.

        ``scene_feature`` is presumably a pooled ``(B, in_channels, 1, 1)``
        descriptor, so the gate broadcasts over space (TODO confirm). The
        returned list has one entry per (content encoder, feature) pair that
        ``zip`` pairs up.
        """
        content_feats = [enc(feat) for enc, feat in zip(self.content_encoders, features)]
        if self.scale_aware_proj:
            scene_feats = [op(scene_feature) for op in self.scene_encoder]
        else:
            # Iterating an nn.Sequential yields its individual layers, which
            # would apply conv, ReLU and conv to scene_feature *separately*.
            # Run the shared projection once and reuse it for every level.
            shared = self.scene_encoder(scene_feature)
            scene_feats = [shared] * len(content_feats)
        return [self.normalizer(sf) * cf for sf, cf in zip(scene_feats, content_feats)]
| |
|
class PSPModule(nn.Module):
    """Pyramid-pooling-style context block with a concat bottleneck.

    The input is concatenated with ``len(bin_sizes)`` parallel branches. Each
    branch is a 1x1 conv-BN-ReLU producing ``in_channels // len(bin_sizes)``
    channels, resized back to the input's H x W. The concatenation is fused
    back to ``in_channels`` by a 3x3 conv-BN-ReLU followed by
    ``Dropout2d(0.1)``, so the output has the input's shape.

    NOTE: by default (``pool=False``) the branches do NOT pool. ``bin_sz`` is
    not used to build the branches, so every branch runs at full resolution
    and the bilinear resize in ``forward`` is a same-size resize. This default
    is kept so existing weights behave identically.

    Pass ``pool=True`` to adaptively average-pool each branch's input to
    ``bin_sz x bin_sz`` first (classic PSPNet behaviour). The parameter layout
    and state_dict keys are the same either way. With pooling, a 1x1 bin plus
    batch size 1 makes BatchNorm fail in training mode.
    """

    def __init__(self, in_channels, bin_sizes=(1, 2, 4, 6), *, pool=False):
        super(PSPModule, self).__init__()
        # Tuple copy: avoids a shared mutable default and caller-side aliasing.
        self.bin_sizes = tuple(bin_sizes)
        self.pool = pool
        out_channels = in_channels // len(self.bin_sizes)
        self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s)
                                     for b_s in self.bin_sizes])
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels + (out_channels * len(self.bin_sizes)), in_channels,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1)
        )

    def _make_stages(self, in_channels, out_channels, bin_sz):
        """Build one branch: 1x1 conv (no bias) -> BatchNorm -> ReLU.

        ``bin_sz`` is unused here. Pooling (when enabled) is done in
        ``forward``, so the branch's parameters are identical in both modes.
        """
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        bn = nn.BatchNorm2d(out_channels)
        relu = nn.ReLU(inplace=True)
        return nn.Sequential(conv, bn, relu)

    def forward(self, features):
        """Return ``bottleneck(cat([features, branch_1, ..., branch_k], dim=1))``."""
        h, w = features.shape[-2:]
        pyramids = [features]
        for bin_sz, stage in zip(self.bin_sizes, self.stages):
            branch_in = F.adaptive_avg_pool2d(features, bin_sz) if self.pool else features
            pyramids.append(F.interpolate(stage(branch_in), size=(h, w), mode='bilinear',
                                          align_corners=True))
        return self.bottleneck(torch.cat(pyramids, dim=1))
| |
|
class Change_detection(nn.Module):
    """Change-detection head over two pre-extracted 4-level feature pyramids.

    Pipeline (see ``forward``):

    1. Per level ``i``, each pyramid is refined by two stacked
       ``Cross_transformer_backbone`` blocks that attend to the other pyramid.
       ``*_a{i}`` / ``*_a{i}{i}`` serve the first pyramid and ``*_b{i}`` /
       ``*_b{i}{i}`` the second. The absolute difference of the two results is
       that level's change feature.
    2. The deepest change feature is passed through ``PPN`` (a ``PSPModule``).
    3. Four branches (``sr1``..``sr4`` with ``Cross_transformer1``..``4``)
       each gate all levels with one level's global context, then fuse them
       using that level as the query.
    4. The branch outputs are upsampled and concatenated
       (512 + 256 + 128 + 64 = 960 channels), then fused to ``fpn_out``
       channels. ``output_fill`` decodes the result into ``num_classes``
       channels; no sigmoid/softmax is applied.

    NOTE(review): ``use_aux``, ``freeze_bn``, ``self.sig`` and ``self.active``
    are accepted or created but not used by ``forward``.
    """

    def __init__(self, num_classes=2, use_aux=True, fpn_out=48, freeze_bn=False, **_):
        super(Change_detection, self).__init__()

        # Channel widths of the four pyramid levels, shallow to deep.
        f_channels = [64, 128, 256, 512]

        self.PPN = PSPModule(f_channels[-1])

        # Cross-attention blocks: first image attends to the second ...
        self.Cross_transformer_backbone_a3 = Cross_transformer_backbone(in_channels = f_channels[3])
        self.Cross_transformer_backbone_a2 = Cross_transformer_backbone(in_channels = f_channels[2])
        self.Cross_transformer_backbone_a1 = Cross_transformer_backbone(in_channels = f_channels[1])
        self.Cross_transformer_backbone_a0 = Cross_transformer_backbone(in_channels = f_channels[0])
        self.Cross_transformer_backbone_a33 = Cross_transformer_backbone(in_channels = f_channels[3])
        self.Cross_transformer_backbone_a22 = Cross_transformer_backbone(in_channels = f_channels[2])
        self.Cross_transformer_backbone_a11 = Cross_transformer_backbone(in_channels = f_channels[1])
        self.Cross_transformer_backbone_a00 = Cross_transformer_backbone(in_channels = f_channels[0])

        # ... and second image attends to the first.
        self.Cross_transformer_backbone_b3 = Cross_transformer_backbone(in_channels = f_channels[3])
        self.Cross_transformer_backbone_b2 = Cross_transformer_backbone(in_channels = f_channels[2])
        self.Cross_transformer_backbone_b1 = Cross_transformer_backbone(in_channels = f_channels[1])
        self.Cross_transformer_backbone_b0 = Cross_transformer_backbone(in_channels = f_channels[0])
        self.Cross_transformer_backbone_b33 = Cross_transformer_backbone(in_channels = f_channels[3])
        self.Cross_transformer_backbone_b22 = Cross_transformer_backbone(in_channels = f_channels[2])
        self.Cross_transformer_backbone_b11 = Cross_transformer_backbone(in_channels = f_channels[1])
        self.Cross_transformer_backbone_b00 = Cross_transformer_backbone(in_channels = f_channels[0])

        self.sig = nn.Sigmoid()  # unused by forward
        self.gap = GlobalAvgPool2D()
        self.sr1 = SceneRelation(in_channels = f_channels[3], channel_list = f_channels, out_channels = f_channels[3], scale_aware_proj=True)
        self.sr2 = SceneRelation(in_channels = f_channels[2], channel_list = f_channels, out_channels = f_channels[2], scale_aware_proj=True)
        self.sr3 = SceneRelation(in_channels = f_channels[1], channel_list = f_channels, out_channels = f_channels[1], scale_aware_proj=True)
        self.sr4 = SceneRelation(in_channels = f_channels[0], channel_list = f_channels, out_channels = f_channels[0], scale_aware_proj=True)

        self.Cross_transformer1 = Cross_transformer(in_channels = f_channels[3])
        self.Cross_transformer2 = Cross_transformer(in_channels = f_channels[2])
        self.Cross_transformer3 = Cross_transformer(in_channels = f_channels[1])
        self.Cross_transformer4 = Cross_transformer(in_channels = f_channels[0])

        self.conv_fusion = nn.Sequential(
            # Input is the concatenation of all four branch outputs (960 channels).
            nn.Conv2d(sum(f_channels), fpn_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(fpn_out),
            nn.ReLU(inplace=True)
        )

        self.output_fill = nn.Sequential(
            # 2x transposed-conv upsampling, then the per-pixel classifier.
            nn.ConvTranspose2d(fpn_out , fpn_out, kernel_size=2, stride = 2, bias=False),
            nn.BatchNorm2d(fpn_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(fpn_out, num_classes, kernel_size=3, padding=1)
        )
        self.active = nn.Sigmoid()  # unused by forward

    def forward(self, x):
        """Predict change scores from a pair of 4-level feature pyramids.

        Args:
            x: A pair ``(features1, features2)``. Each is a sequence of exactly
                four tensors with 64, 128, 256 and 512 channels (shallow to
                deep). Levels with the same index must have identical shapes
                across the two pyramids. These are presumably encoder stage
                outputs for the two images; the encoder is not part of this
                module.

        Returns:
            Tensor of shape ``(B, num_classes, 512, 512)``. The fusion stage
            resizes every level to a fixed 64x64 / 128x128 grid, so the output
            size does not depend on the input resolution.

        Raises:
            ValueError: If either pyramid does not have exactly four levels.
        """
        features1, features2 = x
        if len(features1) != 4 or len(features2) != 4:
            raise ValueError(
                "Change_detection expects two 4-level feature pyramids, got "
                f"{len(features1)} and {len(features2)} levels"
            )
        diffs = self._level_differences(features1, features2)
        diffs[-1] = self.PPN(diffs[-1])
        return self._fuse_levels(diffs)

    def _level_differences(self, features1, features2):
        # Per level i: A' = a_ii(A, a_i(A, B)), B' = b_ii(B, b_i(B, A)); return |A' - B'|.
        diffs = []
        for level, (feat_a, feat_b) in enumerate(zip(features1, features2)):
            a_first = getattr(self, f"Cross_transformer_backbone_a{level}")
            a_second = getattr(self, f"Cross_transformer_backbone_a{level}{level}")
            b_first = getattr(self, f"Cross_transformer_backbone_b{level}")
            b_second = getattr(self, f"Cross_transformer_backbone_b{level}{level}")
            enc_a = a_second(feat_a, a_first(feat_a, feat_b))
            enc_b = b_second(feat_b, b_first(feat_b, feat_a))
            diffs.append(abs(enc_a - enc_b))
        return diffs

    def _fuse_levels(self, diffs):
        # Scene-relation gating + cross-level attention (one branch per level),
        # then concat, fuse and decode.
        # Resize every level once per grid; the 64x64 grid is shared by three branches.
        grid64 = [F.interpolate(d, size=(64, 64), mode='nearest') for d in diffs]
        grid128 = [F.interpolate(d, size=(128, 128), mode='nearest') for d in diffs]
        branches = (
            # (relation, transformer, resized pyramid, query level, upsample factor)
            (self.sr1, self.Cross_transformer1, grid64, 3, 4),
            (self.sr2, self.Cross_transformer2, grid64, 2, 4),
            (self.sr3, self.Cross_transformer3, grid64, 1, 4),
            (self.sr4, self.Cross_transformer4, grid128, 0, 2),
        )
        fused = []
        for relation, transformer, pyramid, level, scale in branches:
            # Global context of the query level gates every level.
            relations = relation(self.gap(diffs[level]), pyramid)
            # Companion levels, ordered deepest to shallowest.
            others = [relations[k] for k in (3, 2, 1, 0) if k != level]
            out = transformer(relations[level], others)
            # 64*4 == 128*2 == 256: all branches meet on one grid.
            fused.append(F.interpolate(out, scale_factor=scale, mode='nearest'))
        return self.output_fill(self.conv_fusion(torch.cat(fused, dim=1)))
| |
|
| |
|
if __name__ == '__main__':
    # Smoke test. The network consumes two 4-level feature pyramids
    # (64/128/256/512 channels), not raw images, and forward() takes them
    # as ONE argument: net((pyramid_a, pyramid_b)).
    # The spatial sizes below assume a stride-4/8/16/32 encoder on a 512x512
    # image (an assumption; the output is (B, num_classes, 512, 512) anyway).
    batch = 1
    channels = (64, 128, 256, 512)
    sizes = (128, 64, 32, 16)
    xa = [torch.randn(batch, c, s, s) for c, s in zip(channels, sizes)]
    xb = [torch.randn(batch, c, s, s) for c, s in zip(channels, sizes)]
    net = Change_detection().eval()
    with torch.no_grad():
        out = net((xa, xb))
    print(out.shape)