| | import torch |
| | import torch.nn as nn |
| | from typing import Optional, Callable, Union, Tuple, Any |
| | import torch |
| | from torch import nn, Tensor |
| | import numpy as np |
| | from typing import Optional |
| | import math |
| | from torch import nn |
| |
|
def makeDivisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """Round ``v`` to the nearest multiple of ``divisor``.

    The result never drops below ``min_value`` (defaults to ``divisor``)
    and never falls under 90% of ``v`` — the standard channel-rounding
    helper used by efficient mobile architectures.
    """
    floor = divisor if min_value is None else min_value
    rounded = max(floor, int(v + divisor / 2) // divisor * divisor)
    # Guard: rounding down must not remove more than 10% of the value.
    return rounded + divisor if rounded < 0.9 * v else rounded
def callMethod(self, ElementName):
    """Dynamically fetch attribute ``ElementName`` from ``self``."""
    return getattr(self, ElementName)
def setMethod(self, ElementName, ElementValue):
    """Dynamically assign ``ElementValue`` to attribute ``ElementName`` on ``self``."""
    return setattr(self, ElementName, ElementValue)
def shuffleTensor(Feature: Tensor, Mode: int = 1) -> list:
    """Randomly permute the spatial positions of one or more feature maps.

    Args:
        Feature: a 4-D (B, C, H, W) tensor, or a list of such tensors.
        Mode: 1 -> one joint permutation over the flattened H*W positions;
            any other value -> independent row and column permutations.

    Returns:
        A list of shuffled tensors. NOTE: the original annotation claimed
        ``-> Tensor`` but the function always returns a list (a single
        tensor input is wrapped; callers unwrap with ``[0]``). The same
        random permutation is reused for every tensor in the list.
    """
    if isinstance(Feature, Tensor):
        Feature = [Feature]
    Indexs = None
    Output = []
    for f in Feature:
        B, C, H, W = f.shape
        if Mode == 1:
            # Shuffle jointly over flattened spatial positions.
            f = f.flatten(2)
            if Indexs is None:
                Indexs = torch.randperm(f.shape[-1], device=f.device)
            # .to() is a no-op for the first tensor; it handles lists whose
            # tensors live on different devices.
            f = f[:, :, Indexs.to(f.device)]
            f = f.reshape(B, C, H, W)
        else:
            # Shuffle rows and columns with independent permutations.
            if Indexs is None:
                Indexs = [torch.randperm(H, device=f.device),
                          torch.randperm(W, device=f.device)]
            f = f[:, :, Indexs[0].to(f.device)]
            f = f[:, :, :, Indexs[1].to(f.device)]
        Output.append(f)
    return Output
class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
    """``nn.AdaptiveAvgPool2d`` plus a profiling hook that reports zero cost."""

    def __init__(self, output_size: Union[int, tuple] = 1):
        # FIX: the original annotation `int or tuple` evaluates to just `int`
        # at runtime ("or" of two classes), which was misleading.
        super(AdaptiveAvgPool2d, self).__init__(output_size=output_size)

    def profileModule(self, Input: Tensor):
        """Run forward and return (output, params, MACs); pooling is counted as free."""
        Output = self.forward(Input)
        return Output, 0.0, 0.0
| |
|
class AdaptiveMaxPool2d(nn.AdaptiveMaxPool2d):
    """``nn.AdaptiveMaxPool2d`` plus a profiling hook that reports zero cost."""

    def __init__(self, output_size: Union[int, tuple] = 1):
        # FIX: the original annotation `int or tuple` evaluates to just `int`
        # at runtime ("or" of two classes), which was misleading.
        super(AdaptiveMaxPool2d, self).__init__(output_size=output_size)

    def profileModule(self, Input: Tensor):
        """Run forward and return (output, params, MACs); pooling is counted as free."""
        Output = self.forward(Input)
        return Output, 0.0, 0.0
class BaseConv2d(nn.Module):
    """Conv2d optionally followed by BatchNorm and an activation layer.

    Padding defaults to a "same"-style value derived from the kernel size
    and dilation; bias defaults to the opposite of ``BNorm`` (no conv bias
    when batch norm absorbs it). Weights are initialised via ``initWeight``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: Optional[int] = 1,
        padding: Optional[int] = None,
        groups: Optional[int] = 1,
        bias: Optional[bool] = None,
        BNorm: bool = False,
        ActLayer: Optional[Callable[..., nn.Module]] = None,
        dilation: int = 1,
        Momentum: Optional[float] = 0.1,
        **kwargs: Any
    ) -> None:
        super(BaseConv2d, self).__init__()

        # "Same"-style padding when none is given.
        if padding is None:
            padding = int((kernel_size - 1) // 2 * dilation)
        # Batch norm makes a conv bias redundant.
        if bias is None:
            bias = not BNorm

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.bias = bias

        self.Conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride, padding, dilation, groups, bias, **kwargs)
        self.Bn = (nn.BatchNorm2d(out_channels, eps=0.001, momentum=Momentum)
                   if BNorm else nn.Identity())

        if ActLayer is None:
            self.Act = None
        elif isinstance(list(ActLayer().named_modules())[0][1], nn.Sigmoid):
            # Sigmoid takes no `inplace` keyword, so instantiate it bare.
            self.Act = ActLayer()
        else:
            self.Act = ActLayer(inplace=True)

        self.apply(initWeight)

    def forward(self, x: Tensor) -> Tensor:
        x = self.Conv(x)
        x = self.Bn(x)
        return x if self.Act is None else self.Act(x)
| |
|
# Normalisation layers whose affine parameters `initWeight` resets to the
# identity transform (weight = 1, bias = 0).
NormLayerTuple = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.SyncBatchNorm,
    nn.LayerNorm,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.GroupNorm,
    nn.BatchNorm3d,
)


def initWeight(Module):
    """Recursively initialise a module tree.

    Conv and Linear weights get Kaiming-uniform init (a = sqrt(5)) with a
    fan-in-scaled uniform bias; normalisation layers are reset to identity;
    container modules are traversed child by child.
    """
    if Module is None:
        return
    if isinstance(Module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
        nn.init.kaiming_uniform_(Module.weight, a=math.sqrt(5))
        if Module.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(Module.weight)
            # Skip the bias entirely when fan-in is degenerate.
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(Module.bias, -bound, bound)
    elif isinstance(Module, NormLayerTuple):
        if Module.weight is not None:
            nn.init.ones_(Module.weight)
        if Module.bias is not None:
            nn.init.zeros_(Module.bias)
    elif isinstance(Module, nn.Linear):
        nn.init.kaiming_uniform_(Module.weight, a=math.sqrt(5))
        if Module.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(Module.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(Module.bias, -bound, bound)
    elif isinstance(Module, (nn.Sequential, nn.ModuleList)):
        for Child in Module:
            initWeight(Child)
    elif list(Module.children()):
        for Child in Module.children():
            initWeight(Child)
class Attention(nn.Module):
    """Squeeze-excite style attention with multi-resolution random pooling.

    During training a pooling resolution is sampled from ``PoolRes`` (and the
    input may first be spatially shuffled via ``shuffleTensor``); at eval time
    a plain 1x1 global average pool is used. The pooled 1x1 map is passed
    through a two-layer SE bottleneck and multiplied back onto the input.
    """

    def __init__(
        self,
        InChannels: int,
        HidChannels: Optional[int] = None,
        SqueezeFactor: int = 4,
        PoolRes: Optional[list] = None,
        Act: Callable[..., nn.Module] = nn.ReLU,
        ScaleAct: Callable[..., nn.Module] = nn.Sigmoid,
        MoCOrder: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        # FIX: the original used the mutable default `PoolRes: list = [1, 2, 3]`,
        # a shared-across-calls anti-pattern; default behavior is unchanged.
        if PoolRes is None:
            PoolRes = [1, 2, 3]
        if HidChannels is None:
            HidChannels = max(makeDivisible(InChannels // SqueezeFactor, 8), 32)

        # One adaptive pool per resolution; Pool1 is always registered because
        # it is the eval-time path.
        AllPoolRes = PoolRes + [1] if 1 not in PoolRes else PoolRes
        for k in AllPoolRes:
            Pooling = AdaptiveAvgPool2d(k)
            setMethod(self, 'Pool%d' % k, Pooling)

        self.SELayer = nn.Sequential(
            BaseConv2d(InChannels, HidChannels, 1, ActLayer=Act),
            BaseConv2d(HidChannels, InChannels, 1, ActLayer=ScaleAct),
        )

        self.PoolRes = PoolRes
        self.MoCOrder = MoCOrder

    def RandomSample(self, x: Tensor) -> Tensor:
        """Produce a 1x1 attention seed map (random resolution in training)."""
        if self.training:
            PoolKeep = np.random.choice(self.PoolRes)
            x1 = shuffleTensor(x)[0] if self.MoCOrder else x
            AttnMap: Tensor = callMethod(self, 'Pool%d' % PoolKeep)(x1)
            if AttnMap.shape[-1] > 1:
                # Keep a single random spatial position as the seed.
                AttnMap = AttnMap.flatten(2)
                # FIX: create the permutation on the map's own device; the
                # original built CPU indices even for GPU tensors.
                Idx = torch.randperm(AttnMap.shape[-1], device=AttnMap.device)[0]
                AttnMap = AttnMap[:, :, Idx]
                AttnMap = AttnMap[:, :, None, None]
        else:
            AttnMap: Tensor = callMethod(self, 'Pool%d' % 1)(x)

        return AttnMap

    def forward(self, x: Tensor) -> Tensor:
        AttnMap = self.RandomSample(x)
        return x * self.SELayer(AttnMap)
| |
|
def channel_shuffle(x, groups):
    """ShuffleNet-style channel shuffle.

    Reshapes (B, g*c, H, W) -> (B, g, c, H, W), swaps the group and
    per-group-channel axes, and flattens back, interleaving channels
    across the ``groups`` groups.
    """
    # FIX: use x.size() instead of the deprecated/discouraged x.data.size();
    # the values are identical.
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batchsize, -1, height, width)
    return x
class GLFA(nn.Module):
    """Global-local feature aggregation block.

    Runs four parallel 3x3 branches with dilations 1-4, channel-shuffles the
    concatenation, fuses with a 1x1 conv, applies multi-scale attention, and
    adds a residual connection back to the input.
    """

    def __init__(self, in_channels):
        super(GLFA, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels

        def dilated_branch(rate):
            # Padding equals dilation for a 3x3 kernel, so every branch
            # preserves the spatial size.
            return nn.Sequential(
                nn.Conv2d(in_channels, in_channels, padding=rate,
                          kernel_size=3, dilation=rate),
                nn.BatchNorm2d(in_channels),
                nn.ReLU(inplace=True),
            )

        # Attribute names conv_1..conv_4 are kept so state_dict keys match.
        self.conv_1 = dilated_branch(1)
        self.conv_2 = dilated_branch(2)
        self.conv_3 = dilated_branch(3)
        self.conv_4 = dilated_branch(4)
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channels * 4, in_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )
        self.mca = Attention(InChannels=in_channels, HidChannels=16)

    def forward(self, x):
        residual = x
        branches = [self.conv_1(x), self.conv_2(x), self.conv_3(x), self.conv_4(x)]
        mixed = channel_shuffle(torch.cat(branches, dim=1), groups=4)
        fused = self.fuse(mixed)
        return self.mca(fused) + residual
| |
|