import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelAttention(nn.Module):
    def __init__(self, in_channels):
        super(ChannelAttention, self).__init__()
        self.in_channels = in_channels
        # Bottleneck MLP with a fixed reduction ratio of 4
        self.linear_1 = nn.Linear(self.in_channels, self.in_channels // 4)
        self.linear_2 = nn.Linear(self.in_channels // 4, self.in_channels)

    def forward(self, input_):
        n_b, n_c, h, w = input_.size()
        # Squeeze: global average pooling collapses each channel to one scalar
        feats = F.adaptive_avg_pool2d(input_, (1, 1)).view((n_b, n_c))
        # Excitation: bottleneck MLP yields per-channel weights in (0, 1)
        feats = F.relu(self.linear_1(feats))
        feats = torch.sigmoid(self.linear_2(feats))
        # Reshape to (N, C, 1, 1); broadcasting rescales the input channel-wise,
        # so an explicit expand_as/clone is unnecessary
        feats = feats.view((n_b, n_c, 1, 1))
        outfeats = input_ * feats
        return outfeats
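
A minimal usage sketch, assuming a random feature map of shape (N, C, H, W) with hypothetical sizes; channel attention only rescales channels, so the output shape matches the input:

# Sanity check with hypothetical shapes: output shape equals input shape
attn = ChannelAttention(in_channels=64)
x = torch.randn(2, 64, 32, 32)  # (N, C, H, W)
out = attn(x)
print(out.shape)  # torch.Size([2, 64, 32, 32])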