| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import warnings
|
|
|
|
|
class MLP(nn.Module):
    """Per-position linear embedding of a feature map.

    All dimensions after dim 1 are flattened into a sequence axis, which is
    moved in front of the channel axis. Each position is then projected from
    ``input_dim`` to ``embed_dim`` channels.

    Args:
        input_dim: size of dim 1 of the incoming tensor.
        embed_dim: size of the last dim of the output.
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        # (N, C, d2, d3, ...) -> (N, C, L) -> (N, L, C), with L = d2 * d3 * ...
        sequence = torch.transpose(torch.flatten(x, start_dim=2), 1, 2)
        return self.proj(sequence)
|
|
|
class UpsampleConvLayer(torch.nn.Module):
    """A single ``nn.ConvTranspose2d`` with padding fixed to 1.

    With ``kernel_size=4`` and ``stride=2`` the output height and width are
    exactly twice those of the input.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        kernel_size: transposed-conv kernel size.
        stride: transposed-conv stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv2d = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=1,
        )

    def forward(self, x):
        return self.conv2d(x)
|
|
|
class ResidualBlock(torch.nn.Module):
    """Scaled residual unit computing ``x + 0.1 * conv2(relu(conv1(x)))``.

    Both convolutions are ``ConvLayer`` instances with a 3x3 kernel, stride 1
    and padding 1 that map ``channels`` to ``channels``. The branch therefore
    has the same shape as the input, which lets it be added back.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        same_size = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = ConvLayer(channels, channels, **same_size)
        self.conv2 = ConvLayer(channels, channels, **same_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        branch = self.conv2(self.relu(self.conv1(x)))
        # The learned branch is down-weighted by 0.1 before the skip addition.
        return x + branch * 0.1
|
|
|
class ConvLayer(nn.Module):
    """Thin wrapper around one ``nn.Conv2d`` stored as ``self.conv2d``.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero padding added to each spatial side.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv2d = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        return self.conv2d(x)
|
|
|
|
|
|
|
def conv_diff(in_channels, out_channels):
    """Build the fusion head: conv3x3 -> ReLU -> BN -> conv3x3 -> ReLU.

    Both convolutions use padding 1, so the spatial size is preserved. The
    layer order (and therefore the ``nn.Sequential`` indices 0-4 used as
    state_dict keys) is fixed.

    Args:
        in_channels: channels entering the first convolution.
        out_channels: channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` with five layers.
    """
    def conv3x3(cin, cout):
        return nn.Conv2d(cin, cout, kernel_size=3, padding=1)

    layers = [
        conv3x3(in_channels, out_channels),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        conv3x3(out_channels, out_channels),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
def make_prediction(in_channels, out_channels):
    """Build a prediction head: conv3x3 -> ReLU -> BN -> conv3x3.

    Unlike ``conv_diff`` there is no activation after the last convolution.
    Padding 1 keeps the spatial size.

    Args:
        in_channels: channels entering the first convolution.
        out_channels: channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` with four layers.
    """
    def conv3x3(cin, cout):
        return nn.Conv2d(cin, cout, kernel_size=3, padding=1)

    layers = [
        conv3x3(in_channels, out_channels),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        conv3x3(out_channels, out_channels),
    ]
    return nn.Sequential(*layers)
|
|
|
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Wrap ``F.interpolate`` and optionally warn about poorly aligned upsampling.

    The check only runs when ``warning`` is true, ``size`` is given and
    ``align_corners`` is truthy. It warns if the output is larger than the
    input in height or width, all four sizes are > 1, and both
    ``(out_h - 1) % (in_h - 1)`` and ``(out_w - 1) % (in_w - 1)`` are
    non-zero.

    Args:
        input: tensor to resize; dims from index 2 on are read as (H, W)
            by the alignment check.
        size: target spatial size, or None when ``scale_factor`` is used.
        scale_factor: multiplier forwarded to ``F.interpolate``.
        mode: interpolation mode forwarded to ``F.interpolate``.
        align_corners: forwarded to ``F.interpolate``.
        warning: set to False to skip the alignment check.

    Returns:
        The tensor returned by ``F.interpolate``.
    """
    if warning and size is not None and align_corners:
        input_h, input_w = tuple(int(x) for x in input.shape[2:])
        output_h, output_w = tuple(int(x) for x in size)
        # Each output dim must be compared with its matching input dim.
        if output_h > input_h or output_w > input_w:
            if ((output_h > 1 and output_w > 1 and input_h > 1
                 and input_w > 1) and (output_h - 1) % (input_h - 1)
                    and (output_w - 1) % (input_w - 1)):
                warnings.warn(
                    f'When align_corners={align_corners}, '
                    'the output would more aligned if '
                    f'input size {(input_h, input_w)} is `x+1` and '
                    f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size=size, scale_factor=scale_factor,
                         mode=mode, align_corners=align_corners)
|
|
|
|
|
class DecoderTransformer_v3(nn.Module):
    """Decoder that fuses two four-level feature pyramids into predictions.

    For every level, the features of both inputs are projected to
    ``embedding_dim`` channels by the same ``MLP`` and concatenated. The pair
    is then fused by that level's ``conv_diff`` head. Levels are refined
    top-down: each finer level adds the bilinearly resized result of the next
    coarser level. Every level also produces an auxiliary prediction.

    The four refined levels are then processed in these steps:
    - They are resized to the resolution of ``inputs2``'s first level,
      concatenated and fused by a 1x1 conv + BN.
    - The result is upsampled twice by transposed convolutions, each
      followed by a residual block.
    - It is mapped to ``output_nc`` channels.

    Args:
        input_transform: 'resize_concat', 'multiple_select', or anything else
            (then ``inputs[in_index]`` is used directly).
        in_index: which entries of the input pyramids to use.
        align_corners: forwarded to ``resize`` in the 'resize_concat' path.
        in_channels: channels of the four selected levels (finest first).
        embedding_dim: common channel width inside the decoder.
        output_nc: channels of every returned prediction.
        decoder_softmax: if True, ``nn.Sigmoid`` (not a softmax, despite the
            name) is applied to every returned prediction.
        feature_strides: strides of the levels; only validated here.
    """

    def __init__(self, input_transform='multiple_select', in_index=[0, 1, 2, 3], align_corners=True,
                 in_channels=[32, 64, 128, 256], embedding_dim=64, output_nc=2,
                 decoder_softmax=False, feature_strides=[2, 4, 8, 16]):
        super(DecoderTransformer_v3, self).__init__()
        # The list defaults are only read, never mutated, so sharing them is harmless.
        assert len(feature_strides) == len(in_channels)
        assert min(feature_strides) == feature_strides[0]

        self.feature_strides = feature_strides
        self.input_transform = input_transform
        self.in_index = in_index
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.output_nc = output_nc
        # Exactly four levels are supported; this unpacking enforces it.
        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        # Per-level projection to embedding_dim, shared by both input streams.
        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)

        # Fuse the concatenated pair of projections: 2*embedding_dim -> embedding_dim.
        self.diff_c4 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c3 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c2 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c1 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)

        # Auxiliary prediction heads, one per level.
        self.make_pred_c4 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c3 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c2 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c1 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)

        # 1x1 conv over the four concatenated levels, back to embedding_dim channels.
        self.linear_fuse = nn.Sequential(
            nn.Conv2d(in_channels=self.embedding_dim * len(in_channels), out_channels=self.embedding_dim,
                      kernel_size=1),
            nn.BatchNorm2d(self.embedding_dim)
        )

        # Each transposed conv (kernel 4, stride 2, padding 1) doubles H and W.
        self.convd2x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
        self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim))
        self.convd1x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
        self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim))
        self.change_probability = ConvLayer(self.embedding_dim, self.output_nc, kernel_size=3, stride=1, padding=1)

        # NOTE: despite the attribute name, this flag applies a sigmoid, not a softmax.
        self.output_softmax = decoder_softmax
        self.active = nn.Sigmoid()

    def _transform_inputs(self, inputs):
        """Select (and optionally merge) the pyramid levels used by the decoder.

        Args:
            inputs (list[Tensor]): multi-level feature maps.

        Returns:
            The transformed inputs:
            - 'resize_concat': a single tensor with all selected levels
              resized to the first one's size and concatenated on dim 1.
            - 'multiple_select': a list of the selected levels.
            - otherwise: ``inputs[self.in_index]``.

        NOTE(review): ``forward`` unpacks the result into four items, so only
        'multiple_select' (or an equivalent four-item selection) appears
        usable there.
        """
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @staticmethod
    def _embed(mlp, feat):
        """Project ``feat`` with ``mlp`` and fold the tokens back to (N, C, H, W)."""
        return mlp(feat).permute(0, 2, 1).reshape(feat.shape[0], -1, feat.shape[2], feat.shape[3])

    def _fuse_level(self, mlp, diff, feat_1, feat_2):
        """Embed one level of both inputs with a shared ``mlp`` and fuse them with ``diff``."""
        return diff(torch.cat((self._embed(mlp, feat_1), self._embed(mlp, feat_2)), dim=1))

    @staticmethod
    def _upsample_like(x, ref):
        """Bilinearly resize ``x`` to the spatial size of ``ref``.

        Resizing to the target's actual size avoids a shape mismatch when a
        finer level is not exactly twice the coarser one (e.g. odd input
        sizes). It gives the same result as ``scale_factor=2`` when it is.
        """
        return F.interpolate(x, size=ref.shape[2:], mode="bilinear", align_corners=False)

    def forward(self, inputs1, inputs2):
        """Decode two feature pyramids.

        Returns:
            A list of five tensors: the auxiliary predictions for levels 4,
            3, 2, 1 (coarse to fine) followed by the final prediction. All are
            passed through a sigmoid when ``decoder_softmax`` was set.
        """
        c1_1, c2_1, c3_1, c4_1 = self._transform_inputs(inputs1)
        c1_2, c2_2, c3_2, c4_2 = self._transform_inputs(inputs2)

        # All levels are fused at the spatial size of inputs2's finest level.
        fuse_size = c1_2.size()[2:]
        outputs = []

        _c4 = self._fuse_level(self.linear_c4, self.diff_c4, c4_1, c4_2)
        outputs.append(self.make_pred_c4(_c4))
        _c4_up = resize(_c4, size=fuse_size, mode='bilinear', align_corners=False)

        _c3 = self._fuse_level(self.linear_c3, self.diff_c3, c3_1, c3_2)
        _c3 = _c3 + self._upsample_like(_c4, _c3)
        outputs.append(self.make_pred_c3(_c3))
        _c3_up = resize(_c3, size=fuse_size, mode='bilinear', align_corners=False)

        _c2 = self._fuse_level(self.linear_c2, self.diff_c2, c2_1, c2_2)
        _c2 = _c2 + self._upsample_like(_c3, _c2)
        outputs.append(self.make_pred_c2(_c2))
        _c2_up = resize(_c2, size=fuse_size, mode='bilinear', align_corners=False)

        _c1 = self._fuse_level(self.linear_c1, self.diff_c1, c1_1, c1_2)
        _c1 = _c1 + self._upsample_like(_c2, _c1)
        outputs.append(self.make_pred_c1(_c1))

        _c = self.linear_fuse(torch.cat((_c4_up, _c3_up, _c2_up, _c1), dim=1))

        x = self.dense_2x(self.convd2x(_c))
        x = self.dense_1x(self.convd1x(x))
        outputs.append(self.change_probability(x))

        if self.output_softmax:
            outputs = [self.active(pred) for pred in outputs]

        return outputs
|
|
|
class ChangeFormer_DE(nn.Module):
    """Decoder-only wrapper around ``DecoderTransformer_v3``.

    The decoder is configured for four levels with channels
    ``[64, 128, 320, 512]`` and strides ``[2, 4, 8, 16]``. Only the last
    (full-resolution) prediction is returned.

    Args:
        output_nc: channels of the returned prediction.
        decoder_softmax: forwarded to the decoder (applies a sigmoid).
        embed_dim: decoder embedding width.
    """

    def __init__(self, output_nc=2, decoder_softmax=False, embed_dim=256):
        super().__init__()
        self.embed_dims = [64, 128, 320, 512]
        self.embedding_dim = embed_dim
        self.TDec_x2 = DecoderTransformer_v3(
            input_transform='multiple_select',
            in_index=[0, 1, 2, 3],
            align_corners=False,
            in_channels=self.embed_dims,
            embedding_dim=self.embedding_dim,
            output_nc=output_nc,
            decoder_softmax=decoder_softmax,
            feature_strides=[2, 4, 8, 16],
        )

    def forward(self, f):
        # f[0] and f[1] are the two feature pyramids fed to the decoder.
        pyramid_a = f[0]
        pyramid_b = f[1]
        predictions = self.TDec_x2(pyramid_a, pyramid_b)
        return predictions[-1]