| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | import warnings
|
| |
|
| |
|
class MLP(nn.Module):
    """Token-wise linear projection of a channels-first feature map.

    All dimensions from index 2 onward are flattened into one token axis and
    moved in front of the channel axis, so a (N, C, H, W) input becomes a
    (N, H*W, C) sequence. ``proj`` then maps each token from ``input_dim``
    to ``embed_dim`` channels, yielding (N, H*W, embed_dim).
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        # (N, C, *spatial) -> (N, C, L) -> (N, L, C); Linear acts on the last axis.
        tokens = x.flatten(start_dim=2).permute(0, 2, 1)
        return self.proj(tokens)
|
| |
|
class UpsampleConvLayer(torch.nn.Module):
    """Learned up-sampling through a single ``nn.ConvTranspose2d`` with padding 1.

    With ``kernel_size=4`` and ``stride=2`` the spatial size exactly doubles:
    (in - 1) * 2 - 2 * 1 + (4 - 1) + 1 = 2 * in.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv2d = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=1,
        )

    def forward(self, x):
        return self.conv2d(x)
|
| |
|
class ResidualBlock(torch.nn.Module):
    """Residual unit: x + 0.1 * conv3x3(relu(conv3x3(x))).

    Both convolutions keep the channel count and use stride 1 / padding 1, so
    the branch output has the same shape as ``x`` and can be added to it.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        branch = self.conv2(self.relu(self.conv1(x)))
        # The fixed 0.1 factor damps the learned branch relative to the identity path.
        return branch * 0.1 + x
|
| |
|
class ConvLayer(nn.Module):
    """Thin wrapper around one ``nn.Conv2d`` (no normalization, no activation)."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv2d = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding
        )

    def forward(self, x):
        return self.conv2d(x)
|
| |
|
| |
|
| |
|
def conv_diff(in_channels, out_channels):
    """Build the fusion block applied to a channel-concatenated feature pair.

    Layer order: conv3x3(in -> out), ReLU, BatchNorm, conv3x3(out -> out), ReLU.
    Padding 1 on both convolutions keeps the spatial size unchanged.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)
|
| |
|
| |
|
def make_prediction(in_channels, out_channels):
    """Build a per-scale prediction head.

    Layer order: conv3x3(in -> out), ReLU, BatchNorm, conv3x3(out -> out).
    The final convolution has no activation, so the head emits raw values.
    Padding 1 keeps the spatial size unchanged.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
    ]
    return nn.Sequential(*layers)
|
| |
|
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Resize ``input`` with ``F.interpolate``, optionally warning about alignment.

    The check runs only when ``warning`` is set, an explicit ``size`` is given,
    ``align_corners`` is truthy, and ``input`` is 4-D. It then warns if the
    output is larger than the input in at least one spatial dimension and
    neither (out_h - 1) is a multiple of (in_h - 1) nor (out_w - 1) a multiple
    of (in_w - 1). In other words, the sizes are not of the "input ``x+1``,
    output ``nx+1``" form the warning recommends.

    Args:
        input: Tensor to resize. For the check, ``input.shape[2:]`` is read
            as (H, W).
        size: Output size, forwarded to ``F.interpolate``. The check treats
            an int as (size, size), matching ``F.interpolate``.
        scale_factor: Forwarded to ``F.interpolate``.
        mode: Interpolation mode, forwarded to ``F.interpolate``.
        align_corners: Forwarded to ``F.interpolate``. A truthy value also
            enables the check.
        warning: Set to False to skip the alignment check entirely.

    Returns:
        The tensor produced by ``F.interpolate``.
    """
    # Only 4-D inputs have the (H, W) spatial layout the check unpacks. Other
    # ranks (e.g. 1-D 'linear' resizing) go straight to F.interpolate instead
    # of crashing on the unpack.
    if warning and size is not None and align_corners and input.dim() == 4:
        input_h, input_w = (int(d) for d in input.shape[2:])
        out_size = (size, size) if isinstance(size, int) else size
        output_h, output_w = (int(d) for d in out_size)
        # Up-sampling in either dimension; width must be compared to input_w.
        if output_h > input_h or output_w > input_w:
            if ((output_h > 1 and output_w > 1 and input_h > 1
                 and input_w > 1) and (output_h - 1) % (input_h - 1)
                    and (output_w - 1) % (input_w - 1)):
                warnings.warn(
                    f'When align_corners={align_corners}, '
                    'the output would more aligned if '
                    f'input size {(input_h, input_w)} is `x+1` and '
                    f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)
|
| |
|
| |
|
class DecoderTransformer_v3(nn.Module):
    """
    Transformer Decoder

    Fuses two sets of four multi-scale feature maps (one set per input).

    For each scale:
      * both inputs are projected to ``embedding_dim`` channels by a shared ``MLP``;
      * the pair is concatenated and fused by a ``conv_diff`` block;
      * the result gets its own prediction head.

    Fusion runs coarse-to-fine (c4 -> c1), each finer scale adding the
    up-sampled fused map of the next coarser one. All scales are then brought
    to c1's resolution and merged with a 1x1 conv. After two learned 2x
    up-sampling + residual stages, a final conv produces the last prediction.

    Args:
        input_transform: How ``_transform_inputs`` selects features:
            'multiple_select' picks ``in_index`` entries, 'resize_concat'
            resizes and concatenates them, anything else indexes directly.
            NOTE: ``forward`` unpacks the result into four maps, so in
            practice it expects 'multiple_select'.
        in_index: Indices of the feature maps to use. Defaults to [0, 1, 2, 3].
        align_corners: Used by the 'resize_concat' transform.
        in_channels: Channel counts of the four maps, finest (c1) first.
            Defaults to [32, 64, 128, 256].
        embedding_dim: Common channel width inside the decoder.
        output_nc: Output channels of every prediction head.
        decoder_softmax: If True, every prediction is passed through a
            sigmoid (despite the name, no softmax is applied).
        feature_strides: Per-scale strides. They are only validated here and
            are not used in the computation. Defaults to [2, 4, 8, 16].
    """

    def __init__(self, input_transform='multiple_select', in_index=None, align_corners=True,
                 in_channels=None, embedding_dim=64, output_nc=2,
                 decoder_softmax=False, feature_strides=None):
        super(DecoderTransformer_v3, self).__init__()

        # None sentinels instead of list literals so instances never share a
        # mutable default; the effective defaults are unchanged.
        if in_index is None:
            in_index = [0, 1, 2, 3]
        if in_channels is None:
            in_channels = [32, 64, 128, 256]
        if feature_strides is None:
            feature_strides = [2, 4, 8, 16]

        assert len(feature_strides) == len(in_channels)
        assert min(feature_strides) == feature_strides[0]

        self.feature_strides = feature_strides
        self.input_transform = input_transform
        self.in_index = in_index
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.output_nc = output_nc
        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        # Per-scale token projections to a common width (shared by both inputs).
        # Submodule names and creation order are kept so state_dict keys and
        # seeded initialization stay the same.
        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)

        # Per-scale fusion of the concatenated pair (2 * embedding_dim channels in).
        self.diff_c4 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c3 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c2 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c1 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)

        # Per-scale prediction heads.
        self.make_pred_c4 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c3 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c2 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c1 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)

        # 1x1 fusion of all scales once they share c1's resolution.
        self.linear_fuse = nn.Sequential(
            nn.Conv2d(in_channels=self.embedding_dim * len(in_channels), out_channels=self.embedding_dim,
                      kernel_size=1),
            nn.BatchNorm2d(self.embedding_dim)
        )

        # Two learned up-sampling stages (ConvTranspose2d, kernel 4, stride 2),
        # each followed by a residual refinement block.
        self.convd2x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
        self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim))
        self.convd1x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
        self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim))
        self.change_probability = ConvLayer(self.embedding_dim, self.output_nc, kernel_size=3, stride=1, padding=1)

        # Despite the flag name, the activation applied to outputs is a sigmoid.
        self.output_softmax = decoder_softmax
        self.active = nn.Sigmoid()

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            'multiple_select': the list of maps at ``self.in_index``.
            'resize_concat': the selected maps resized to the first one's
            spatial size and concatenated on dim 1.
            Otherwise: ``inputs[self.in_index]``.
        """
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @staticmethod
    def _embed(linear, feat, n):
        """Project ``feat`` token-wise with ``linear`` and fold back to a map.

        ``linear`` returns (n, H*W, C'). This is permuted to (n, C', H*W) and
        reshaped to (n, C', H, W) using ``feat``'s spatial size.
        """
        return linear(feat).permute(0, 2, 1).reshape(n, -1, feat.shape[2], feat.shape[3])

    @staticmethod
    def _upsample_to(x, ref):
        """Bilinearly resize ``x`` to ``ref``'s spatial size (align_corners=False).

        Resizing to an explicit size rather than using ``scale_factor=2`` keeps
        the residual additions shape-consistent even when a finer map is not
        exactly twice the coarser one. For exact doubling the result is the same.
        """
        return F.interpolate(x, size=ref.shape[2:], mode="bilinear", align_corners=False)

    def forward(self, inputs1, inputs2):
        """Decode two feature pyramids into a list of predictions.

        Args:
            inputs1, inputs2: Feature lists, transformed by
                ``_transform_inputs`` and unpacked as (c1, c2, c3, c4),
                finest first.

        Returns:
            list[Tensor]: ``[p_c4, p_c3, p_c2, p_c1, cp]``, i.e. the per-scale
            predictions from coarse to fine, then the final prediction. Each is
            passed through a sigmoid when ``decoder_softmax`` was set.
        """
        x_1 = self._transform_inputs(inputs1)
        x_2 = self._transform_inputs(inputs2)

        c1_1, c2_1, c3_1, c4_1 = x_1
        c1_2, c2_2, c3_2, c4_2 = x_2

        n = c4_1.shape[0]
        target_size = c1_2.size()[2:]
        outputs = []

        # Scale 4 (coarsest).
        _c4_1 = self._embed(self.linear_c4, c4_1, n)
        _c4_2 = self._embed(self.linear_c4, c4_2, n)
        _c4 = self.diff_c4(torch.cat((_c4_1, _c4_2), dim=1))
        outputs.append(self.make_pred_c4(_c4))
        _c4_up = resize(_c4, size=target_size, mode='bilinear', align_corners=False)

        # Scale 3: fused pair plus up-sampled scale-4 result.
        _c3_1 = self._embed(self.linear_c3, c3_1, n)
        _c3_2 = self._embed(self.linear_c3, c3_2, n)
        _c3 = self.diff_c3(torch.cat((_c3_1, _c3_2), dim=1)) + self._upsample_to(_c4, _c3_1)
        outputs.append(self.make_pred_c3(_c3))
        _c3_up = resize(_c3, size=target_size, mode='bilinear', align_corners=False)

        # Scale 2.
        _c2_1 = self._embed(self.linear_c2, c2_1, n)
        _c2_2 = self._embed(self.linear_c2, c2_2, n)
        _c2 = self.diff_c2(torch.cat((_c2_1, _c2_2), dim=1)) + self._upsample_to(_c3, _c2_1)
        outputs.append(self.make_pred_c2(_c2))
        _c2_up = resize(_c2, size=target_size, mode='bilinear', align_corners=False)

        # Scale 1 (finest); already at the fusion resolution.
        _c1_1 = self._embed(self.linear_c1, c1_1, n)
        _c1_2 = self._embed(self.linear_c1, c1_2, n)
        _c1 = self.diff_c1(torch.cat((_c1_1, _c1_2), dim=1)) + self._upsample_to(_c2, _c1_1)
        outputs.append(self.make_pred_c1(_c1))

        # Merge all scales, then two learned up-sampling + refinement stages.
        _c = self.linear_fuse(torch.cat((_c4_up, _c3_up, _c2_up, _c1), dim=1))
        x = self.dense_2x(self.convd2x(_c))
        x = self.dense_1x(self.convd1x(x))

        outputs.append(self.change_probability(x))

        if self.output_softmax:
            outputs = [self.active(pred) for pred in outputs]

        return outputs
|
| |
|
class ChangeFormer_DE(nn.Module):
    """Decoder-only change-detection head wrapping ``DecoderTransformer_v3``.

    ``forward`` receives ``f`` holding two feature pyramids (``f[0]`` and
    ``f[1]``). Their four maps are configured to have 64, 128, 320 and 512
    channels. Only the decoder's final prediction is returned.
    """

    def __init__(self, output_nc=2, decoder_softmax=False, embed_dim=256):
        super().__init__()

        self.embed_dims = [64, 128, 320, 512]
        self.embedding_dim = embed_dim

        self.TDec_x2 = DecoderTransformer_v3(
            input_transform='multiple_select',
            in_index=[0, 1, 2, 3],
            align_corners=False,
            in_channels=self.embed_dims,
            embedding_dim=self.embedding_dim,
            output_nc=output_nc,
            decoder_softmax=decoder_softmax,
            feature_strides=[2, 4, 8, 16],
        )

    def forward(self, f):
        first_feats, second_feats = f[0], f[1]
        predictions = self.TDec_x2(first_feats, second_feats)
        # The per-scale side outputs are computed by the decoder but dropped here.
        return predictions[-1]