| | import cv2 |
| | import torch |
| | from torch import nn |
| | from einops.layers.torch import Rearrange |
| | from .DCT import Learnable_DCT2D |
| | |
| |
|
class Block(nn.Module):
    """ConvNeXtV2-style block whose residual branch is gated by spatial attention.

    Pipeline: 7x7 depthwise conv -> LayerNorm -> Linear (4x expansion) -> GELU
    -> GRN -> Linear (projection back to ``dim``) -> spatial-attention gating
    -> stochastic depth -> residual add.

    Args:
        dim (int): Number of input channels (the output has the same count).
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Stochastic depth is only instantiated when it would actually drop something.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.attention = Spatial_Attention()

    def forward(self, x):
        """Apply the block to a channels-first feature map and return one of the same shape."""
        shortcut = x
        x = self.dwconv(x)
        # The norm / pointwise layers act on the last axis, so move channels last.
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)

        # Back to channels-first for the attention gate and the residual add.
        x = x.permute(0, 3, 1, 2)
        attention = self.attention(x)
        # Resize the attention map to the feature map's spatial size.
        # Functional equivalent of nn.UpsamplingBilinear2d (bilinear,
        # align_corners=True) without building a module on every call.
        attention = nn.functional.interpolate(
            attention, size=x.shape[2:], mode='bilinear', align_corners=True)
        x = x * attention
        return shortcut + self.drop_path(x)
| |
|
class Spatial_Attention(nn.Module):
    """Build a single-channel 7x7 attention map from channel-wise statistics.

    The per-pixel mean and max over channels are stacked, average-pooled to a
    7x7 grid, fused by a 7x7 convolution into one channel and finally refined
    by a single-head transformer block operating on that 7x7 grid.
    """

    def __init__(self):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.attention = TransformerBlock(1, 1, heads=1, dim_head=1, img_size=[7, 7])

    def forward(self, x):
        """Return the attention map computed from a channels-first feature map."""
        # Collapse the channel axis two ways, keeping it as a singleton dim.
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values

        stats = torch.cat((channel_mean, channel_max), dim=1)
        pooled = self.avgpool(stats)
        fused = self.conv(pooled)
        return self.attention(fused)
| |
|
class TransformerBlock(nn.Module):
    """Pre-norm attention + feed-forward block that works on feature maps.

    Both sub-layers flatten the map into a token sequence, apply
    ``PreNorm(LayerNorm)`` around the wrapped layer and fold the tokens back
    into an ``(ih, iw)`` grid.

    Args:
        inp (int): Input channel count.
        oup (int): Output channel count.
        heads (int): Number of attention heads.
        dim_head (int): Width of each attention head.
        img_size (sequence of two ints): ``(ih, iw)`` grid used when folding
            tokens back into a map. Must be provided; it is unpacked directly.
        downsample (bool): If True, max-pool the input (3x3, stride 2) on both
            the shortcut and the attention path and project the shortcut with
            a 1x1 convolution.
        dropout (float): Dropout rate passed to the attention and
            feed-forward layers.
    """

    def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=False, dropout=0.):
        super().__init__()
        self.downsample = downsample
        self.ih, self.iw = img_size

        if downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        # Core layers first, wrappers second (same construction order as before).
        attention_layer = Attention(inp, oup, heads, dim_head, dropout)
        feed_forward = FeedForward(oup, int(inp * 4), dropout)

        self.attn = self._on_tokens(attention_layer, inp)
        self.ff = self._on_tokens(feed_forward, oup)

    def _on_tokens(self, layer, dim):
        """Wrap ``layer`` so it consumes and produces feature maps instead of tokens."""
        return nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(dim, layer, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        )

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers with residual connections."""
        if self.downsample:
            shortcut = self.proj(self.pool1(x))
            x = shortcut + self.attn(self.pool2(x))
        else:
            x = x + self.attn(x)
        return x + self.ff(x)
| |
|
| |
|
class CSATv2(nn.Module):
    """CSATv2 image classifier.

    A learnable DCT stem feeds four stages of ``Block`` modules; stages 3 and 4
    additionally end with two ``TransformerBlock`` modules each. Stages 1-3
    finish with a channels-first LayerNorm and a stride-2 convolution. The
    final feature map is averaged over its spatial axes, normalized and
    classified by a linear head.

    Args:
        img_size (int): Input resolution. Required: it sizes the transformer
            token grids as ``img_size / 32`` (stage 3) and ``img_size / 64``
            (stage 4).
        num_classes (int): Output size of the classification head.
        drop_path_rate (float): Maximum stochastic depth rate; per-block rates
            rise linearly from 0 to this value over all ``Block`` modules.
        head_init_scale (float): Factor applied to the head's initial weight
            and bias.

    Raises:
        TypeError: If ``img_size`` is not provided.
    """

    def __init__(self, img_size=None, num_classes=1000, drop_path_rate=0, head_init_scale=1):
        super().__init__()
        if img_size is None:
            # Fail early with a readable message instead of `None / 32`.
            raise TypeError("CSATv2 requires img_size (an int) to size its transformer blocks")

        dims = [32, 72, 168, 386]
        channel_order = "channels_first"
        depths = [2, 2, 6, 4]
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        def conv_blocks(stage):
            # Index by cumulative depth so the rate keeps increasing across
            # stages; restarting at 0 per stage would leave most of the
            # schedule (including drop_path_rate itself) unused.
            offset = sum(depths[:stage])
            return [
                Block(dim=dims[stage], drop_path=dp_rates[offset + i])
                for i in range(depths[stage])
            ]

        def transformer_blocks(stage, reduction):
            # `reduction` is the total downsampling factor at this stage.
            return [
                TransformerBlock(
                    inp=dims[stage],
                    oup=dims[stage],
                    img_size=[int(img_size / reduction), int(img_size / reduction)],
                )
                for _ in range(2)
            ]

        def downsample(stage):
            return [
                LayerNorm(dims[stage], eps=1e-6, data_format=channel_order),
                nn.Conv2d(dims[stage], dims[stage + 1], kernel_size=2, stride=2),
            ]

        # Layer order inside each Sequential is unchanged, so state_dict keys
        # (stagesN.<index>...) stay compatible with existing checkpoints.
        self.stages1 = nn.Sequential(*conv_blocks(0), *downsample(0))
        self.stages2 = nn.Sequential(*conv_blocks(1), *downsample(1))
        self.stages3 = nn.Sequential(
            *conv_blocks(2), *transformer_blocks(2, 32), *downsample(2))
        self.stages4 = nn.Sequential(*conv_blocks(3), *transformer_blocks(3, 64))

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)
        # Created after apply(), so _init_weights never touches the DCT stem.
        self.dct = Learnable_DCT2D(8)

    def load_checkpoint(self, checkpoint):
        """Load backbone weights from ``checkpoint``, leaving the head untouched.

        The checkpoint must hold its weights under ``'state_dict'`` or
        ``'model'``. ``'module.backbone.'`` and ``'resnet.'`` are stripped from
        key names, keys unknown to this model are ignored, and the
        classification head keeps its current values.

        Raises:
            KeyError: If neither ``'state_dict'`` nor ``'model'`` is present.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state = torch.load(checkpoint, map_location='cpu')
        try:
            state_dict = state['state_dict']
        except KeyError:
            state_dict = state['model']

        # Iterate over a snapshot of the keys because entries are renamed in place.
        for key in list(state_dict.keys()):
            state_dict[key.replace('module.backbone.', '').replace('resnet.', '')] = state_dict.pop(key)

        model_dict = self.state_dict()
        weights = {k: v for k, v in state_dict.items() if k in model_dict}

        model_dict.update(weights)
        # Drop the head so strict=False leaves the current head parameters as they are.
        del model_dict['head.bias']
        del model_dict['head.weight']
        self.load_state_dict(model_dict, strict=False)

    def preprocess(self, x):
        """Convert an image from BGR to YCrCb using OpenCV."""
        x = cv2.cvtColor(x, cv2.COLOR_BGR2YCR_CB)
        return x

    def _init_weights(self, m):
        """Truncated-normal init for conv/linear weights; zero their biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            # Layers built with bias=False have no bias tensor to reset.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits for the input batch."""
        x = self.dct(x)
        x = self.stages1(x)
        x = self.stages2(x)
        x = self.stages3(x)
        x = self.stages4(x)
        # Global average pool over the two spatial axes before the head.
        x = self.norm(x.mean([-2, -1]))
        x = self.head(x)
        return x
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| | import math |
| | import warnings |
| |
|
class LayerNorm(nn.Module):
    """Layer normalization over the channel axis for two memory layouts.

    ``channels_last`` (default) expects inputs shaped
    (batch_size, height, width, channels) and normalizes the last axis;
    ``channels_first`` expects (batch_size, channels, height, width) and
    normalizes axis 1.

    Args:
        normalized_shape (int): Number of channels.
        eps (float): Value added to the variance for numerical stability.
        data_format (str): ``"channels_last"`` or ``"channels_first"``.

    Raises:
        NotImplementedError: For any other ``data_format``.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalize ``x`` along its channel axis and apply the affine parameters."""
        layout = self.data_format
        if layout == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if layout == "channels_first":
            # Manual normalization over axis 1 (biased variance).
            mean = x.mean(1, keepdim=True)
            variance = (x - mean).pow(2).mean(1, keepdim=True)
            normalized = (x - mean) / torch.sqrt(variance + self.eps)
            # Broadcast the per-channel parameters over the spatial axes.
            return self.weight[:, None, None] * normalized + self.bias[:, None, None]
        return None
| |
|
| |
|
class GRN(nn.Module):
    """Global Response Normalization layer.

    For every channel, the L2 norm over axes 1 and 2 is divided by the mean of
    those norms across channels; the input is rescaled by that ratio, passed
    through the learnable ``gamma`` / ``beta`` and added back to the input.
    Both parameters start at zero, so a fresh layer is the identity.

    Args:
        dim (int): Size of the last (channel) axis of the input.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        """Apply GRN to a 4-D input whose last axis has ``dim`` entries."""
        # Per-sample, per-channel response aggregated over axes 1 and 2.
        channel_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Relative importance of each channel (epsilon guards an all-zero input).
        relative = channel_norm / (channel_norm.mean(dim=-1, keepdim=True) + 1e-6)
        calibrated = x * relative
        return self.gamma * calibrated + self.beta + x
| |
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample along axis 0 is either zeroed entirely or kept and scaled by
    ``1 / (1 - drop_prob)``.

    Args:
        x: Input tensor; axis 0 is treated as the sample axis.
        drop_prob: Probability of dropping a sample. ``None`` and ``0`` both
            disable dropping.
        training: Dropping is only active when this is True.

    Returns:
        ``x`` itself when dropping is inactive, otherwise a new tensor of the
        same shape.
    """
    # `not drop_prob` also covers None, which would otherwise fail at `1 - None`.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0:
        # Every path is dropped; dividing by keep_prob would produce NaNs.
        return torch.zeros_like(x)
    # One Bernoulli(keep_prob) draw per sample, broadcast over the remaining axes.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    # Rescale survivors so the expected value matches the input.
    output = x.div(keep_prob) * random_tensor
    return output
| |
|
| |
|
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth).

    Args:
        drop_prob: Drop probability forwarded to ``drop_path`` on every call.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply stochastic depth; it is only active while the module is in training mode."""
        probability = self.drop_prob
        is_training = self.training
        return drop_path(x, probability, is_training)
| |
|
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim (int): Size of the input and output feature axis.
        hidden_dim (int): Size of the hidden layer.
        dropout (float): Dropout rate used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        hidden_dropout = nn.Dropout(dropout)
        project = nn.Linear(hidden_dim, dim)
        output_dropout = nn.Dropout(dropout)
        self.net = nn.Sequential(
            expand, activation, hidden_dropout, project, output_dropout)

    def forward(self, x):
        """Run the MLP on the last axis of ``x``."""
        out = self.net(x)
        return out
| |
|
class PreNorm(nn.Module):
    """Normalize the input before handing it to a wrapped callable.

    Args:
        dim (int): Argument passed to the ``norm`` factory.
        fn: Callable applied to the normalized input.
        norm: Factory called as ``norm(dim)`` to build the normalization layer.
    """

    def __init__(self, dim, fn, norm):
        super().__init__()
        self.norm = norm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
| |
|
class Attention(nn.Module):
    """Multi-head self-attention preceded by a convolutional positional embedding.

    Args:
        inp (int): Feature size of the input tokens.
        oup (int): Feature size produced by the output projection.
        heads (int): Number of attention heads.
        dim_head (int): Width of each head.
        dropout (float): Dropout rate applied after the output projection.
    """

    def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        # The output projection is skipped only for a single head whose width
        # already equals the input width.
        if heads == 1 and dim_head == inp:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, oup),
                nn.Dropout(dropout)
            )
        self.pos_embed = PosCNN(in_chans=inp)

    def forward(self, x):
        """Attend over a (batch, tokens, features) sequence."""
        tokens = self.pos_embed(x)
        # One fused projection, then split into q / k / v and expose the heads.
        query, key, value = (
            rearrange(part, 'b n (h d) -> b h n d', h=self.heads)
            for part in self.to_qkv(tokens).chunk(3, dim=-1)
        )

        scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)
        context = torch.matmul(weights, value)
        # Merge the heads back into the feature axis.
        merged = rearrange(context, 'b h n d -> b n (h d)')
        return self.to_out(merged)
| |
|
| | |
class PosCNN(nn.Module):
    """Positional encoding via a residual 3x3 depthwise convolution over the token grid.

    Args:
        in_chans (int): Feature size of the tokens (also the number of conv groups).
    """

    def __init__(self, in_chans):
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans, in_chans, kernel_size=3, stride=1, padding=1,
            bias=True, groups=in_chans)

    def forward(self, x):
        """Add the convolutional position signal to a (B, N, C) token sequence.

        The tokens are laid out on a square grid of side ``int(N ** 0.5)``, so
        N is expected to be a perfect square.
        """
        batch, num_tokens, channels = x.shape
        side = int(num_tokens ** 0.5)
        # (B, N, C) -> (B, C, side, side)
        grid = x.transpose(1, 2).view(batch, channels, side, side)
        encoded = self.proj(grid) + grid
        # Back to (B, N, C).
        return encoded.flatten(2).transpose(1, 2)
| |
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` in place with samples from a truncated normal distribution.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    the interval :math:`[a, b]`. The sampling method works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor`` object, after being filled.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
    return filled
| |
|
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with truncated-normal samples, without tracking gradients.

    Sampling uses the inverse-CDF method: draw uniformly between the CDF values
    of the two cutoffs, map through the inverse error function, then scale and
    shift to the requested mean and standard deviation.

    Args:
        tensor: Tensor to overwrite.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower cutoff.
        b: Upper cutoff.

    Returns:
        The same ``tensor`` object.
    """
    root_two = math.sqrt(2.)

    def standard_normal_cdf(value):
        # CDF of the standard normal distribution.
        return (1. + math.erf(value / root_two)) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # CDF values of the lower and upper cutoffs.
        cdf_low = standard_normal_cdf((a - mean) / std)
        cdf_high = standard_normal_cdf((b - mean) / std)

        # Uniform draw on [2*cdf_low - 1, 2*cdf_high - 1], the range that the
        # inverse error function maps onto the truncated standard normal.
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)
        tensor.erfinv_()

        # Rescale to the requested standard deviation and mean.
        tensor.mul_(std * root_two)
        tensor.add_(mean)

        # Clamp so rounding cannot push a value outside [a, b].
        tensor.clamp_(min=a, max=b)
        return tensor