| | import numpy as np |
| | import fvcore.nn.weight_init as weight_init |
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| |
|
| |
|
| | from rscd.models.decoderheads.transformer_decoder.position_encoding import PositionEmbeddingSine |
| | from rscd.models.decoderheads.transformer_decoder.transformer import _get_clones, _get_activation_fn |
| | from rscd.models.decoderheads.pixel_decoder.ops.modules import MSDeformAttn |
| |
|
| | |
class MSDeformAttnTransformerEncoderLayer(nn.Module):
    """Single encoder layer: deformable self-attention followed by a feed-forward net.

    Each sub-layer adds its (dropout-regularised) output back onto its input
    and then applies a LayerNorm.
    """

    def __init__(self,
                 d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        """
        Args:
            d_model: channel width of the tokens (LayerNorm / Linear size).
            d_ffn: hidden width of the feed-forward network.
            dropout: dropout probability shared by the three dropout layers.
            activation: activation name resolved through ``_get_activation_fn``.
            n_levels, n_heads, n_points: forwarded unchanged to ``MSDeformAttn``.
        """
        super().__init__()

        # Attention sub-layer.
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Feed-forward sub-layer.
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_ffn(self, src):
        """Apply the feed-forward sub-layer with residual connection and LayerNorm."""
        hidden = self.linear1(src)
        hidden = self.activation(hidden)
        hidden = self.dropout2(hidden)
        ffn_out = self.linear2(hidden)
        return self.norm2(src + self.dropout3(ffn_out))

    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
        """Run attention then the feed-forward net on ``src`` and return the result.

        The positional embedding is added only to the first (query) argument of
        the attention call; the raw ``src`` is passed as its third argument.
        The remaining arguments are handed to ``MSDeformAttn`` untouched.
        """
        query = self.with_pos_embed(src, pos)
        attn_out = self.self_attn(
            query, reference_points, src, spatial_shapes, level_start_index, padding_mask
        )
        src = self.norm1(src + self.dropout1(attn_out))

        return self.forward_ffn(src)
| |
|
| |
|
class MSDeformAttnTransformerEncoder(nn.Module):
    """Stack of ``num_layers`` clones of a deformable-attention encoder layer."""

    def __init__(self, encoder_layer, num_layers):
        """
        Args:
            encoder_layer: layer instance that is cloned via ``_get_clones``.
            num_layers: number of clones to stack.
        """
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Build normalised pixel-centre reference points for every level.

        Args:
            spatial_shapes: iterable of ``(H, W)`` pairs, one per feature level.
            valid_ratios: tensor indexed as ``[batch, level, 2]`` whose last
                axis is ``(ratio_w, ratio_h)``.
            device: device on which the coordinate grids are created.

        Returns:
            Tensor of shape ``(batch, sum(H * W), num_levels, 2)`` holding
            ``(x, y)`` coordinates: the points of each level, rescaled by the
            valid ratio of every level.
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            # Pixel centres at 0.5, 1.5, ...; "ij" indexing is stated explicitly
            # (it is the historical default) to avoid the torch.meshgrid
            # deprecation warning and a future change of default.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
                indexing="ij",
            )
            # Normalise by the valid (unpadded) extent of this level.
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            reference_points_list.append(torch.stack((ref_x, ref_y), -1))
        reference_points = torch.cat(reference_points_list, 1)
        # Broadcast against every level's valid ratio: (B, N, 1, 2) * (B, 1, L, 2).
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
        """Feed ``src`` through every layer in order and return the final output.

        Reference points are computed once from ``spatial_shapes`` and
        ``valid_ratios`` and shared by all layers.
        """
        output = src
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        for layer in self.layers:
            output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)

        return output
| |
|
| |
|
class MSDeformAttnTransformerEncoderOnly(nn.Module):
    """Encoder-only wrapper: flattens multi-level feature maps and encodes them."""

    def __init__(self, d_model=256, nhead=8,
                 num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
                 activation="relu",
                 num_feature_levels=4, enc_n_points=4,
                 ):
        """
        Args:
            d_model: token channel width.
            nhead: number of attention heads.
            num_encoder_layers: number of stacked encoder layers.
            dim_feedforward: hidden width of each layer's feed-forward net.
            dropout: dropout probability used inside each layer.
            activation: activation name used inside each layer.
            num_feature_levels: number of feature levels; also the number of
                rows of the learned ``level_embed``.
            enc_n_points: sampling points argument passed to each layer.
        """
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead

        layer_prototype = MSDeformAttnTransformerEncoderLayer(
            d_model, dim_feedforward,
            dropout, activation,
            num_feature_levels, nhead, enc_n_points,
        )
        self.encoder = MSDeformAttnTransformerEncoder(layer_prototype, num_encoder_layers)

        # One learned embedding vector per level; filled in by _reset_parameters.
        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise weights: Xavier for matrices, then module-specific overrides."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
        # Deformable-attention modules re-apply their own initialisation afterwards.
        for module in self.modules():
            if isinstance(module, MSDeformAttn):
                module._reset_parameters()
        nn.init.normal_(self.level_embed)

    def get_valid_ratio(self, mask):
        """Return the per-sample fraction of unmasked width and height.

        Args:
            mask: boolean tensor of shape ``(batch, H, W)``; True marks padding.

        Returns:
            Float tensor of shape ``(batch, 2)`` ordered ``(ratio_w, ratio_h)``,
            measured along the first row and first column of the mask.
        """
        _, height, width = mask.shape
        unmasked = ~mask
        valid_rows = unmasked[:, :, 0].sum(1)
        valid_cols = unmasked[:, 0, :].sum(1)
        ratio_h = valid_rows.float() / height
        ratio_w = valid_cols.float() / width
        return torch.stack([ratio_w, ratio_h], -1)

    def forward(self, srcs, pos_embeds):
        """Flatten all levels into one token sequence and run the encoder.

        Args:
            srcs: list of 4-D feature maps ``(batch, C, H, W)``, one per level.
            pos_embeds: list of positional embeddings matching ``srcs``.

        Returns:
            ``(memory, spatial_shapes, level_start_index)`` where ``memory`` is
            the encoder output, ``spatial_shapes`` is a long tensor of the
            per-level ``(H, W)`` and ``level_start_index`` holds the offset of
            each level inside the flattened sequence.
        """
        # No padding information is available here, so every mask is all-False.
        masks = [
            torch.zeros((feat.size(0), feat.size(2), feat.size(3)), device=feat.device, dtype=torch.bool)
            for feat in srcs
        ]

        level_tokens = []
        level_masks = []
        level_pos = []
        shapes = []
        for lvl, (feat, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            _, _, h, w = feat.shape
            shapes.append((h, w))
            # (B, C, H, W) -> (B, H*W, C)
            level_tokens.append(feat.flatten(2).transpose(1, 2))
            level_masks.append(mask.flatten(1))
            pos_tokens = pos_embed.flatten(2).transpose(1, 2)
            # Tag every position with the embedding of the level it came from.
            level_pos.append(pos_tokens + self.level_embed[lvl].view(1, 1, -1))

        src_flatten = torch.cat(level_tokens, 1)
        mask_flatten = torch.cat(level_masks, 1)
        lvl_pos_embed_flatten = torch.cat(level_pos, 1)

        spatial_shapes = torch.as_tensor(shapes, dtype=torch.long, device=src_flatten.device)
        tokens_per_level = spatial_shapes.prod(1)
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros((1, )), tokens_per_level.cumsum(0)[:-1])
        )
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        memory = self.encoder(
            src_flatten, spatial_shapes, level_start_index, valid_ratios,
            lvl_pos_embed_flatten, mask_flatten,
        )

        return memory, spatial_shapes, level_start_index
| |
|
class MSDeformAttnPixelDecoder(nn.Module):
    """Pixel decoder: deformable-attention encoder on selected levels plus an FPN top-up."""

    def __init__(
        self,
        input_shape,
        transformer_dropout=0.1,
        transformer_nheads=8,
        transformer_dim_feedforward=2048,
        transformer_enc_layers=6,
        conv_dim=256,
        mask_dim=256,
        transformer_in_features= ["res3", "res4", "res5"],
        common_stride=4,
    ):
        """
        Args:
            input_shape: mapping from feature name to an object exposing
                ``.channel`` and ``.stride``.
            transformer_dropout: dropout probability inside the transformer.
            transformer_nheads: number of attention heads.
            transformer_dim_feedforward: feed-forward hidden width.
            transformer_enc_layers: number of encoder layers.
            conv_dim: channel width of all intermediate conv layers.
            mask_dim: channel width of the final mask-feature conv.
            transformer_in_features: names of the features fed to the transformer.
            common_stride: stride of the finest map the FPN path should reach.
        """
        super().__init__()

        def _conv_gn(in_channels):
            # 1x1 projection to conv_dim followed by GroupNorm.
            return nn.Sequential(
                nn.Conv2d(in_channels, conv_dim, kernel_size=1),
                nn.GroupNorm(32, conv_dim),
            )

        # Only the requested levels go through the transformer.
        transformer_input_shape = {
            name: shape for name, shape in input_shape.items() if name in transformer_in_features
        }

        # All backbone levels, in the iteration order of ``input_shape``.
        self.in_features = list(input_shape.keys())
        self.feature_channels = [shape.channel for shape in input_shape.values()]

        self.transformer_in_features = list(transformer_input_shape.keys())
        transformer_in_channels = [shape.channel for shape in transformer_input_shape.values()]
        self.transformer_feature_strides = [shape.stride for shape in transformer_input_shape.values()]

        self.transformer_num_feature_levels = len(self.transformer_in_features)
        if self.transformer_num_feature_levels > 1:
            # Projections are stored in reversed level order, matching the
            # reversed iteration in forward_features.
            self.input_proj = nn.ModuleList(
                [_conv_gn(channels) for channels in transformer_in_channels[::-1]]
            )
        else:
            self.input_proj = nn.ModuleList([_conv_gn(transformer_in_channels[-1])])

        for proj in self.input_proj:
            conv = proj[0]
            nn.init.xavier_uniform_(conv.weight, gain=1)
            nn.init.constant_(conv.bias, 0)

        self.transformer = MSDeformAttnTransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            num_feature_levels=self.transformer_num_feature_levels,
        )
        N_steps = conv_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        self.mask_dim = mask_dim
        self.mask_features = nn.Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        weight_init.c2_xavier_fill(self.mask_features)

        # Number of leading outputs returned as multi-scale features.
        self.maskformer_num_feature_levels = 3
        self.common_stride = common_stride

        # Extra FPN levels needed to go from the finest transformer level
        # down to ``common_stride``.
        finest_stride = min(self.transformer_feature_strides)
        self.num_fpn_levels = int(np.log2(finest_stride) - np.log2(self.common_stride))

        lateral_convs = []
        output_convs = []
        for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]):
            lateral_conv = nn.Sequential(
                nn.Conv2d(in_channels, conv_dim, kernel_size=1),
                nn.GroupNorm(32, conv_dim),
                nn.ReLU(inplace=True),
            )
            output_conv = nn.Sequential(
                nn.Conv2d(conv_dim, conv_dim, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(32, conv_dim),
                nn.ReLU(inplace=True),
            )
            weight_init.c2_xavier_fill(lateral_conv[0])
            weight_init.c2_xavier_fill(output_conv[0])
            # Registered by name; the plain lists below are only for indexing.
            self.add_module("adapter_{}".format(idx + 1), lateral_conv)
            self.add_module("layer_{}".format(idx + 1), output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)

        # Stored reversed so index 0 corresponds to the last of the FPN levels.
        self.lateral_convs = list(reversed(lateral_convs))
        self.output_convs = list(reversed(output_convs))

    def forward_features(self, features):
        """Decode backbone features.

        Args:
            features: mapping from feature name to a 4-D feature map.

        Returns:
            ``(mask_features, first_level_output, multi_scale_features)`` where
            ``mask_features`` is the mask conv applied to the last decoded map,
            ``first_level_output`` is the first transformer output map and
            ``multi_scale_features`` is the list of the first
            ``maskformer_num_feature_levels`` decoded maps.
        """
        srcs = []
        pos = []
        # Transformer levels are visited in reversed order of registration.
        for idx, name in enumerate(self.transformer_in_features[::-1]):
            feat = features[name].float()
            srcs.append(self.input_proj[idx](feat))
            pos.append(self.pe_layer(feat))

        memory, spatial_shapes, level_start_index = self.transformer(srcs, pos)
        batch_size = memory.shape[0]

        # Token count of each level, derived from consecutive start offsets.
        num_levels = self.transformer_num_feature_levels
        split_sizes = []
        for i in range(num_levels):
            stop = level_start_index[i + 1] if i < num_levels - 1 else memory.shape[1]
            split_sizes.append(stop - level_start_index[i])
        per_level_tokens = torch.split(memory, split_sizes, dim=1)

        # (B, H*W, C) -> (B, C, H, W) for every level.
        out = [
            tokens.transpose(1, 2).view(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1])
            for i, tokens in enumerate(per_level_tokens)
        ]

        # Top-down FPN pass over the levels that skipped the transformer.
        for idx, name in enumerate(self.in_features[:self.num_fpn_levels][::-1]):
            feat = features[name].float()
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            cur_fpn = lateral_conv(feat)
            upsampled = F.interpolate(out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False)
            out.append(output_conv(cur_fpn + upsampled))

        multi_scale_features = out[:self.maskformer_num_feature_levels]

        return self.mask_features(out[-1]), out[0], multi_scale_features
| |
|
class MSDeformAttnPixelDecoder4ScalesFASeg(nn.Module):
    """Four-scale pixel decoder with learnable (depthwise-conv) positional encodings.

    The transformer levels are encoded with deformable attention; ``res2`` is
    projected separately and its positional encoding is returned alongside
    those of the transformer levels.
    """

    def __init__(
        self,
        input_shape,
        transformer_dropout= 0.1,
        transformer_nheads= 8,
        transformer_dim_feedforward= 2048,
        transformer_enc_layers= 6,
        conv_dim= 256,
        mask_dim= 256,
        transformer_in_features= ["res3", "res4", "res5"],
        common_stride= 4,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: mapping from feature name to an object exposing
                ``.channel`` and ``.stride``
            transformer_dropout: dropout probability in transformer
            transformer_nheads: number of heads in transformer
            transformer_dim_feedforward: dimension of feedforward network
            transformer_enc_layers: number of transformer encoder layers
            conv_dim: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            transformer_in_features: names of the features fed to the transformer
            common_stride: stride of the finest map the FPN path should reach
        """
        super().__init__()
        transformer_input_shape = {
            k: v for k, v in input_shape.items() if k in transformer_in_features
        }

        # All backbone levels, in the iteration order of ``input_shape``.
        self.in_features = [k for k, v in input_shape.items()]
        self.feature_channels = [v.channel for k, v in input_shape.items()]

        # Levels that go through the transformer.
        self.transformer_in_features = [k for k, v in transformer_input_shape.items()]
        transformer_in_channels = [v.channel for k, v in transformer_input_shape.items()]
        self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape.items()]

        self.transformer_num_feature_levels = len(self.transformer_in_features)

        def _make_pos_conv():
            # Depthwise 3x3 conv acting as a learnable positional encoding.
            # Sized by ``conv_dim`` (previously hard-coded to 256) because it is
            # applied to the ``conv_dim``-channel output of ``input_proj``.
            return nn.Conv2d(conv_dim, conv_dim, kernel_size=3, stride=1,
                             padding=1, bias=True, groups=conv_dim)

        pos_linear_list = []
        if self.transformer_num_feature_levels > 1:
            input_proj_list = []
            # Stored in reversed level order, matching forward_features.
            for in_channels in transformer_in_channels[::-1]:
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, conv_dim, kernel_size=1),
                    nn.GroupNorm(32, conv_dim),
                ))

                # One positional encoder per transformer level.
                pos_linear_list.append(_make_pos_conv())

            # Extra projection for res2, stored last.
            # NOTE(review): assumes res2 has half the channels of the last
            # level iterated above — confirm against the backbone in use.
            input_proj_list.append(nn.Sequential(
                nn.Conv2d(in_channels // 2, conv_dim, kernel_size=1),
                nn.GroupNorm(32, conv_dim),
            ))
            self.input_proj = nn.ModuleList(input_proj_list)

        else:
            # NOTE(review): no dedicated res2 projection is created in this
            # branch, yet forward_features applies ``input_proj[-1]`` to res2.
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1),
                    nn.GroupNorm(32, conv_dim),
                )])

        # Positional encoder for res2, stored last.
        pos_linear_list.append(_make_pos_conv())

        self.pos_linear = nn.ModuleList(pos_linear_list)

        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

        self.transformer = MSDeformAttnTransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            num_feature_levels=self.transformer_num_feature_levels,
        )

        self.mask_dim = mask_dim
        self.mask_features = nn.Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        weight_init.c2_xavier_fill(self.mask_features)

        # Number of leading outputs returned as multi-scale features.
        self.maskformer_num_feature_levels = 4
        self.common_stride = common_stride

        # Extra FPN levels needed to go from the finest transformer level
        # down to ``common_stride``.
        stride = min(self.transformer_feature_strides)
        self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride))

        lateral_convs = []
        output_convs = []

        for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]):
            lateral_conv = nn.Sequential(nn.Conv2d(in_channels, conv_dim, kernel_size=1),
                                         nn.GroupNorm(32, conv_dim),
                                         nn.ReLU(inplace=True))

            output_conv = nn.Sequential(nn.Conv2d(conv_dim, conv_dim, kernel_size=3, stride=1, padding=1),
                                        nn.GroupNorm(32, conv_dim),
                                        nn.ReLU(inplace=True))

            weight_init.c2_xavier_fill(lateral_conv[0])
            weight_init.c2_xavier_fill(output_conv[0])
            # Registered by name; the plain lists below are only for indexing.
            self.add_module("adapter_{}".format(idx + 1), lateral_conv)
            self.add_module("layer_{}".format(idx + 1), output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)

        # Stored reversed so index 0 corresponds to the last of the FPN levels.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]

    def forward_features(self, features):
        """Decode backbone features.

        Args:
            features: mapping from feature name to a 4-D feature map; must
                contain ``'res2'`` in addition to the transformer levels.

        Returns:
            ``(mask_features, first_level_output, multi_scale_features, pos)``
            where ``pos`` lists the learned positional encodings of the
            transformer levels followed by that of res2.
        """
        srcs = []
        pos = []
        num_levels = self.transformer_num_feature_levels

        # Transformer levels are visited in reversed order of registration.
        for idx, f in enumerate(self.transformer_in_features[::-1]):
            x = features[f].float()
            x = self.input_proj[idx](x)

            srcs.append(x)
            # Positional encoding is computed from the projected feature itself.
            pos.append(self.pos_linear[idx](x))

        # res2 is cast to float like every other feature consumed here.
        x_res2 = self.input_proj[-1](features['res2'].float())
        pos.append(self.pos_linear[-1](x_res2))
        srcs.append(x_res2)

        # Only the transformer levels are encoded; res2 (appended last) is left
        # out. The slice follows the configured level count rather than a
        # hard-coded 3 so other ``transformer_in_features`` settings work.
        y, spatial_shapes, level_start_index = self.transformer(srcs[:num_levels], pos[:num_levels])
        bs = y.shape[0]

        # Token count of each level, derived from consecutive start offsets.
        split_size_or_sections = [None] * num_levels
        for i in range(num_levels):
            if i < num_levels - 1:
                split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
            else:
                split_size_or_sections[i] = y.shape[1] - level_start_index[i]
        y = torch.split(y, split_size_or_sections, dim=1)

        # (B, H*W, C) -> (B, C, H, W) for every level.
        out = [
            z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1])
            for i, z in enumerate(y)
        ]

        # Top-down FPN pass over the levels that skipped the transformer.
        for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):
            x = features[f].float()
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            cur_fpn = lateral_conv(x)

            y = cur_fpn + F.interpolate(out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False)
            y = output_conv(y)
            out.append(y)

        multi_scale_features = out[:self.maskformer_num_feature_levels]

        return self.mask_features(out[-1]), out[0], multi_scale_features, pos