| | |
| | |
| | |
| |
|
'''
NAFSSR: Stereo Image Super-Resolution Using NAFNet

@InProceedings{Chu2022NAFSSR,
  author = {Xiaojie Chu and Liangyu Chen and Wenqing Yu},
  title = {NAFSSR: Stereo Image Super-Resolution Using NAFNet},
  booktitle = {CVPRW},
  year = {2022},
}
'''
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from basicsr.models.archs.NAFNet_arch import LayerNorm2d, NAFBlock |
| | from basicsr.models.archs.arch_util import MySequential |
| | from basicsr.models.archs.local_arch import Local_Base |
| |
|
class SCAM(nn.Module):
    '''
    Stereo Cross Attention Module (SCAM)

    Fuses a left-view and a right-view feature map by letting every position
    attend to the positions of the other view that lie on the same row.
    Both inputs are expected as 4-D (B, C, H, W) tensors with ``c`` channels.
    '''
    def __init__(self, c):
        '''
        Args:
            c: number of channels of each view's feature map; also sets the
               attention scaling factor ``c ** -0.5``.
        '''
        super().__init__()
        # Dot-product logits are multiplied by 1 / sqrt(c).
        self.scale = c ** -0.5

        def pointwise():
            # 1x1 convolution that keeps the channel count unchanged.
            return nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)

        # Normalisation + projection producing the attention queries.
        self.norm_l = LayerNorm2d(c)
        self.norm_r = LayerNorm2d(c)
        self.l_proj1 = pointwise()
        self.r_proj1 = pointwise()

        # Per-channel gates on the fused features; they start at zero, so a
        # freshly constructed module returns its inputs unchanged.
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

        # Projections producing the values (applied to the raw inputs).
        self.l_proj2 = pointwise()
        self.r_proj2 = pointwise()

    def forward(self, x_l, x_r):
        '''
        Args:
            x_l: left-view features, (B, C, H, W).
            x_r: right-view features, (B, C, H, W).

        Returns:
            Tuple ``(x_l + beta * F_r2l, x_r + gamma * F_l2r)`` with the same
            shapes as the inputs.
        '''
        # Left queries as (B, H, W, C); right queries already transposed to
        # (B, H, C, W) so that a single matmul yields the row-wise logits.
        query_left = self.l_proj1(self.norm_l(x_l)).permute(0, 2, 3, 1)
        query_right_t = self.r_proj1(self.norm_r(x_r)).permute(0, 2, 1, 3)

        # Values in channel-last layout, (B, H, W, C).
        value_left = self.l_proj2(x_l).permute(0, 2, 3, 1)
        value_right = self.r_proj2(x_r).permute(0, 2, 3, 1)

        # (B, H, W_left, W_right) similarity between positions on one row.
        scores = torch.matmul(query_left, query_right_t) * self.scale

        # Right -> left: normalise over right positions.
        weights_r2l = torch.softmax(scores, dim=-1)
        fused_r2l = torch.matmul(weights_r2l, value_right)
        # Left -> right: same logits transposed, normalised over left positions.
        weights_l2r = torch.softmax(scores.transpose(-1, -2), dim=-1)
        fused_l2r = torch.matmul(weights_l2r, value_left)

        # Back to (B, C, H, W), gate per channel, and add residually.
        fused_r2l = fused_r2l.permute(0, 3, 1, 2) * self.beta
        fused_l2r = fused_l2r.permute(0, 3, 1, 2) * self.gamma
        return x_l + fused_r2l, x_r + fused_l2r
| |
|
class DropPath(nn.Module):
    '''
    Wrapper that randomly bypasses ``module`` during training.

    While training, one random draw per forward call decides whether the
    wrapped module is skipped entirely (the inputs are returned as-is).
    When it is not skipped, the change the module made to each input is
    scaled by ``1 / (1 - drop_rate)``. In eval mode the wrapped module is
    always applied and its output is returned unmodified.
    '''
    def __init__(self, drop_rate, module):
        '''
        Args:
            drop_rate: probability of skipping ``module`` in training mode.
            module: callable taking the features as positional arguments and
                returning one output per input.
        '''
        super().__init__()
        self.drop_rate = drop_rate
        self.module = module

    def forward(self, *feats):
        '''Return the (possibly skipped / rescaled) outputs for ``feats``.'''
        is_training = self.training

        # Skip the wrapped module altogether; the input tuple is handed back.
        if is_training and np.random.rand() < self.drop_rate:
            return feats

        outputs = self.module(*feats)
        if not is_training:
            return outputs

        # Compensate for the skipped passes by amplifying the update.
        scale = 1. / (1 - self.drop_rate)
        if scale == 1.:
            return outputs
        return tuple(
            [old + scale * (new - old) for old, new in zip(feats, outputs)]
        )
| |
|
class NAFBlockSR(nn.Module):
    '''
    NAFBlock for Super-Resolution

    Runs one shared NAFBlock over every incoming feature map and, when
    fusion is enabled, passes the results through a SCAM afterwards.
    '''
    def __init__(self, c, fusion=False, drop_out_rate=0.):
        '''
        Args:
            c: channel count handed to the NAFBlock (and to SCAM).
            fusion: when truthy, a SCAM is created and applied after the block.
            drop_out_rate: forwarded to NAFBlock as ``drop_out_rate``.
        '''
        super().__init__()
        self.blk = NAFBlock(c, drop_out_rate=drop_out_rate)
        if fusion:
            self.fusion = SCAM(c)
        else:
            self.fusion = None

    def forward(self, *feats):
        '''Apply the shared block to each feature, then fuse if configured.'''
        # The same block (same weights) processes every view.
        processed = tuple([self.blk(feat) for feat in feats])
        if self.fusion is not None:
            processed = self.fusion(*processed)
        return processed
| |
|
class NAFNetSR(nn.Module):
    '''
    NAFNet for Super-Resolution

    A 3x3 convolution lifts each view to ``width`` channels, a stack of
    DropPath-wrapped NAFBlockSR blocks processes the features, and a 3x3
    convolution followed by PixelShuffle upsamples them. The result is added
    to a bilinearly upsampled copy of the input.
    '''
    def __init__(self, up_scale=4, width=48, num_blks=16, img_channel=3, drop_path_rate=0., drop_out_rate=0., fusion_from=-1, fusion_to=-1, dual=False):
        '''
        Args:
            up_scale: upsampling factor (PixelShuffle factor and interpolation scale).
            width: number of feature channels inside the body.
            num_blks: number of NAFBlockSR blocks in the body.
            img_channel: channels of a single view's image.
            drop_path_rate: drop rate handed to every DropPath wrapper.
            drop_out_rate: forwarded to every NAFBlockSR.
            fusion_from: first block index (inclusive) that gets fusion.
            fusion_to: last block index (inclusive) that gets fusion.
            dual: if True, the input is split into two views along dim 1.
        '''
        super().__init__()
        self.dual = dual
        self.intro = nn.Conv2d(
            in_channels=img_channel, out_channels=width,
            kernel_size=3, padding=1, stride=1, groups=1, bias=True,
        )

        # Blocks whose index falls in [fusion_from, fusion_to] fuse the views.
        blocks = []
        for idx in range(num_blks):
            block = NAFBlockSR(
                width,
                fusion=(fusion_from <= idx <= fusion_to),
                drop_out_rate=drop_out_rate,
            )
            blocks.append(DropPath(drop_path_rate, block))
        # NOTE(review): MySequential is defined elsewhere; presumably it
        # threads several positional tensors through each block — confirm.
        self.body = MySequential(*blocks)

        # Expand channels by up_scale**2, then rearrange them into space.
        self.up = nn.Sequential(
            nn.Conv2d(
                in_channels=width, out_channels=img_channel * up_scale**2,
                kernel_size=3, padding=1, stride=1, groups=1, bias=True,
            ),
            nn.PixelShuffle(up_scale),
        )
        self.up_scale = up_scale

    def forward(self, inp):
        '''
        Args:
            inp: input image batch; with ``dual`` the two views are stacked
                along dim 1 and are split into two equal chunks.

        Returns:
            The upsampled views concatenated along dim 1, plus the bilinearly
            upsampled input.
        '''
        # Global skip connection: plain bilinear upsampling of the input.
        upsampled_input = F.interpolate(inp, scale_factor=self.up_scale, mode='bilinear')

        views = inp.chunk(2, dim=1) if self.dual else (inp, )

        feats = [self.intro(view) for view in views]
        feats = self.body(*feats)

        restored = [self.up(feat) for feat in feats]
        return torch.cat(restored, dim=1) + upsampled_input
| |
|
class NAFSSR(Local_Base, NAFNetSR):
    '''
    Stereo (dual-view) NAFNetSR combined with the Local_Base mixin.

    The network is built with ``dual=True`` and ``img_channel=3`` and is then
    switched to eval mode and passed through ``convert`` once at construction.
    '''
    def __init__(self, *args, train_size=(1, 6, 30, 90), fast_imp=False, fusion_from=-1, fusion_to=1000, **kwargs):
        '''
        Args:
            *args, **kwargs: forwarded to ``NAFNetSR.__init__``.
            train_size: 4-tuple (N, C, H, W); H and W determine ``base_size``.
            fast_imp: forwarded to ``convert``.
            fusion_from: first block index that uses fusion.
            fusion_to: last block index that uses fusion.
        '''
        Local_Base.__init__(self)
        NAFNetSR.__init__(
            self, *args,
            img_channel=3,
            fusion_from=fusion_from,
            fusion_to=fusion_to,
            dual=True,
            **kwargs
        )

        # base_size is 1.5x the training patch's spatial size, truncated to int.
        _, _, height, width = train_size
        base_size = tuple(int(side * 1.5) for side in (height, width))

        # NOTE(review): `convert` is not defined in this file (presumably it
        # comes from Local_Base); it is invoked in eval mode without gradients.
        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
| |
|
if __name__ == '__main__':
    # Quick complexity report for one NAFSSR configuration.
    num_blks = 128
    width = 128
    droppath=0.1
    train_size = (1, 6, 30, 90)

    net = NAFSSR(
        up_scale=2,
        train_size=train_size,
        fast_imp=True,
        width=width,
        num_blks=num_blks,
        drop_path_rate=droppath,
    )

    # Two stacked 3-channel views at 64x64 (no batch dimension).
    inp_shape = (6, 64, 64)

    from ptflops import get_model_complexity_info
    FLOPS = 0
    macs, params = get_model_complexity_info(
        net, inp_shape, verbose=False, print_per_layer_stat=True
    )

    print(params)
    # `macs` is sliced as a string: the last four characters are dropped
    # before converting the remainder to float.
    macs = float(macs[:-4]) + FLOPS / 10 ** 9

    print('mac', macs, params)
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
| |
|
| |
|
| |
|