| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from .layers import * |
class PAA_d(nn.Module):
    """Decoder head that fuses multi-scale features and applies axial self-attention.

    The feature maps given to ``forward`` are resized to the spatial size of the
    first one and concatenated along dim 1. The result is projected to ``depth``
    channels by ``conv1``. It is then fed to a height-wise and a width-wise
    ``SelfAttention`` branch, and the two outputs are summed. Three more
    ``Conv2d`` layers refine the sum, and ``conv5`` produces the output map.

    Args:
        in_channel: Channel count of the concatenated input, i.e. the sum of
            the channel counts of all tensors passed to ``forward``.
        out_channel: Number of channels produced by ``conv5``.
        depth: Channel width of the intermediate layers.
        base_size: Optional ``(H, W)`` pair. When given together with
            ``stage``, each entry is floor-divided by ``2 ** stage`` to form
            ``stage_size``.
        stage: Optional integer stage index used with ``base_size``.
    """

    def __init__(self, in_channel, out_channel=1, depth=64, base_size=None, stage=None):
        super(PAA_d, self).__init__()
        # NOTE(review): the third Conv2d argument is presumably the kernel size,
        # and ``bn=False`` presumably disables normalization on the output
        # layer. Confirm against .layers.Conv2d.
        self.conv1 = Conv2d(in_channel, depth, 3)
        self.conv2 = Conv2d(depth, depth, 3)
        self.conv3 = Conv2d(depth, depth, 3)
        self.conv4 = Conv2d(depth, depth, 3)
        self.conv5 = Conv2d(depth, out_channel, 3, bn=False)

        self.base_size = base_size
        self.stage = stage

        if base_size is not None and stage is not None:
            self.stage_size = (base_size[0] // (2 ** stage), base_size[1] // (2 ** stage))
        else:
            # Use a tuple so both branches produce the same type.
            self.stage_size = (None, None)

        # NOTE(review): how SelfAttention uses its third argument (an axis
        # length, or None when base_size/stage are not given) is defined in
        # .layers. Confirm there.
        self.Hattn = SelfAttention(depth, 'h', self.stage_size[0])
        self.Wattn = SelfAttention(depth, 'w', self.stage_size[1])

    def upsample(self, img, size):
        """Bilinearly resize ``img`` to spatial ``size`` with ``align_corners=True``.

        This is a method rather than a lambda attribute so that the module can
        be pickled (e.g. by ``torch.save(model)``). The call signature
        ``upsample(img, size)`` is unchanged.
        """
        return F.interpolate(img, size=size, mode='bilinear', align_corners=True)

    def forward(self, fs):
        """Fuse the feature maps in ``fs`` and decode them.

        Args:
            fs: Non-empty sequence (list or tuple) of 4-D tensors, as required
                by bilinear ``F.interpolate``. ``fs[0]`` fixes the target
                spatial size, and every later tensor is resized to it. The
                sequence itself is not modified.

        Returns:
            ``(fx, out)``. ``fx`` is the ``conv4`` output with ``depth``
            channels. ``out`` is the ``conv5`` output with ``out_channel``
            channels.

        Raises:
            IndexError: if ``fs`` is empty.
        """
        fx = fs[0]
        target_size = fx.shape[-2:]
        # Build a fresh list instead of writing back into ``fs`` so the
        # caller's sequence is left untouched.
        resized = [fx] + [self.upsample(f, target_size) for f in fs[1:]]
        # Concatenate in reverse order: the last input's channels come first.
        fx = torch.cat(resized[::-1], dim=1)

        fx = self.conv1(fx)

        # Height-wise and width-wise attention run on the same features and
        # are combined by summation.
        Hfx = self.Hattn(fx)
        Wfx = self.Wattn(fx)

        fx = self.conv2(Hfx + Wfx)
        fx = self.conv3(fx)
        fx = self.conv4(fx)
        out = self.conv5(fx)

        return fx, out