| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .GFM import GFM_Module |
| from .DGG import DGG_Module |
| from .ISF import ISF_Module |
|
|
|
|
class MLP(nn.Module):
    """Linear token embedding used by the decoder.

    The trailing (spatial) dimensions of a channels-first tensor are flattened
    into a single token axis, and one ``nn.Linear`` is applied over the
    channel dimension.
    """

    def __init__(self, input_dim, embed_dim):
        super().__init__()
        # Keep the attribute name ``proj`` so existing checkpoints still load.
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        """Map ``(B, C, *spatial)`` to ``(B, prod(spatial), embed_dim)``."""
        tokens = x.flatten(start_dim=2).permute(0, 2, 1)
        return self.proj(tokens)
|
|
|
|
class HiF_Decoder(nn.Module):
    """Hierarchical Factorized Decoder.

    Projects four encoder feature maps to a common width ``decoder_channels``
    and resizes c2-c4 to the spatial size of c1. The maps then go through:

    * a per-level GFM split whose first/second outputs form two branches;
    * one ISF module per branch;
    * three rounds of paired GFMs whose outputs are exchanged between the
      branches.

    A gated residual fusion (DGG) and ``Dropout2d`` come last.

    Args:
        encoder_channels: channel counts of the four encoder stages, ordered
            ``(c1, c2, c3, c4)``. Only the first four entries are read.
        decoder_channels: common embedding width after projection.
        dropout: drop probability of the final ``nn.Dropout2d``.

    NOTE(review): the channel arguments (ISF at ``decoder_channels``, DGG at
    ``decoder_channels // 4``) appear to assume that each
    ``GFM_Module(in_ch, out_ch)`` returns two tensors of ``out_ch // 2``
    channels. If so, ``decoder_channels`` must be divisible by 8. Confirm
    against GFM.py.
    """

    def __init__(
        self,
        encoder_channels=(64, 128, 320, 512),  # tuple: avoids a shared mutable default
        decoder_channels=256,
        dropout=0.1,
    ):
        super().__init__()
        d = decoder_channels

        # Attribute names and construction order are kept unchanged so that
        # state_dict keys and seeded parameter initialisation stay the same.

        # Per-level linear projections to the common width ``d``.
        self.linear_c4 = MLP(input_dim=encoder_channels[3], embed_dim=d)
        self.linear_c3 = MLP(input_dim=encoder_channels[2], embed_dim=d)
        self.linear_c2 = MLP(input_dim=encoder_channels[1], embed_dim=d)
        self.linear_c1 = MLP(input_dim=encoder_channels[0], embed_dim=d)

        self.dropout = nn.Dropout2d(dropout)

        # Stage 1: one GFM per encoder level.
        self.gfm_c4_1 = GFM_Module(d, d // 2)
        self.gfm_c3_1 = GFM_Module(d, d // 2)
        self.gfm_c2_1 = GFM_Module(d, d // 2)
        self.gfm_c1_1 = GFM_Module(d, d // 2)

        # Branch GFMs for the three exchange rounds.
        self.gfm_c_o_1 = GFM_Module(d, d // 2)
        self.gfm_c_e_1 = GFM_Module(d, d // 2)

        self.gfm_c_o_2 = GFM_Module(d // 2, d // 4)
        self.gfm_c_e_2 = GFM_Module(d // 2, d // 4)

        self.gfm_c_o_3 = GFM_Module(d // 4, d // 8)
        self.gfm_c_e_3 = GFM_Module(d // 4, d // 8)

        # One ISF module per branch, applied after stage 1.
        self.cyclic_shuffle_enhancer_o = ISF_Module(channels=d, groups=4, kernel_size=3, cyclic_percent=0.0)
        self.cyclic_shuffle_enhancer_e = ISF_Module(channels=d, groups=4, kernel_size=3, cyclic_percent=0.0)

        # Gated fusion over the concatenated final outputs.
        self.gatefuser = DGG_Module(channels=d // 4, groups=4)

    @staticmethod
    def _embed(proj, feat, size=None):
        """Project a 4-D map with ``proj`` and reshape back to ``(N, D, H, W)``.

        ``proj`` must map ``(N, C, H, W)`` to ``(N, H*W, D)``, as ``MLP`` does.
        When ``size`` is given, the result is bilinearly resized to it
        (``align_corners=False``).
        """
        n, _, h, w = feat.shape
        out = proj(feat).permute(0, 2, 1).reshape(n, -1, h, w)
        if size is not None:
            out = F.interpolate(out, size=size, mode='bilinear', align_corners=False)
        return out

    @staticmethod
    def _cross(o_pair, e_pair):
        """Exchange GFM outputs between the two branches.

        Returns ``(cat(o_first, e_first), cat(o_second, e_second))``, where each
        concatenation is along the channel dimension.
        """
        (o_first, o_second), (e_first, e_second) = o_pair, e_pair
        return (
            torch.cat([o_first, e_first], dim=1),
            torch.cat([o_second, e_second], dim=1),
        )

    def forward(self, encoder_features):
        """Decode four encoder feature maps into one fused map.

        Args:
            encoder_features: exactly four 4-D tensors ``(c1, c2, c3, c4)``
                shaped ``(N, C_i, H_i, W_i)``, with
                ``C_i == encoder_channels[i]``.

        Returns:
            ``dropout(x + gatefuser(x))``, where ``x`` concatenates the final
            GFM outputs. Presumably shaped
            ``(N, decoder_channels // 4, H_1, W_1)``; this depends on
            GFM/ISF/DGG preserving spatial size (TODO confirm).
        """
        c1, c2, c3, c4 = encoder_features
        size = (c1.shape[2], c1.shape[3])

        # Project every level to ``decoder_channels``. c2-c4 are resized to
        # c1's spatial size; c1 already has it.
        _c4 = self._embed(self.linear_c4, c4, size)
        _c3 = self._embed(self.linear_c3, c3, size)
        _c2 = self._embed(self.linear_c2, c2, size)
        _c1 = self._embed(self.linear_c1, c1)

        # Stage 1: each level gets its own GFM. The first outputs form branch
        # "o" and the second outputs form branch "e" (level order c4, c3, c2, c1).
        c4_o, c4_e = self.gfm_c4_1(_c4)
        c3_o, c3_e = self.gfm_c3_1(_c3)
        c2_o, c2_e = self.gfm_c2_1(_c2)
        c1_o, c1_e = self.gfm_c1_1(_c1)
        cat_o = torch.cat([c4_o, c3_o, c2_o, c1_o], dim=1)
        cat_e = torch.cat([c4_e, c3_e, c2_e, c1_e], dim=1)
        branch_o = self.cyclic_shuffle_enhancer_o(cat_o)
        branch_e = self.cyclic_shuffle_enhancer_e(cat_e)

        # Rounds 1 and 2: split each branch again, then exchange outputs before
        # the next GFM pair. The "o" GFM always runs before the "e" GFM.
        branch_o, branch_e = self._cross(self.gfm_c_o_1(branch_o), self.gfm_c_e_1(branch_e))
        branch_o, branch_e = self._cross(self.gfm_c_o_2(branch_o), self.gfm_c_e_2(branch_e))

        # Round 3: final split, concatenated as [o_first, e_first, o_second, e_second].
        (o_first, o_second), (e_first, e_second) = self.gfm_c_o_3(branch_o), self.gfm_c_e_3(branch_e)
        x = torch.cat([o_first, e_first, o_second, e_second], dim=1)

        # Gated residual fusion, then channel dropout.
        x = x + self.gatefuser(x)
        return self.dropout(x)
|
|