| """ |
| original code from rwightman: |
| https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py |
| """ |

import math
from collections import OrderedDict
from functools import partial

import torch
import torch.hub
import torch.nn as nn
import torch.nn.functional as F

# DropPath, Mlp and Block are defined locally in this file; the timm classes of the
# same names are intentionally not imported so the local definitions are not shadowed.
from timm.layers import to_2tuple, trunc_normal_
from timm.models import register_model
from timm.models.vision_transformer import _cfg


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # one random value per sample, broadcast over the remaining dimensions
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize into a 0/1 keep mask
    output = x.div(keep_prob) * random_tensor
    return output


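# Illustrative sketch (not part of the model): drop_path zeroes entire samples with
# probability drop_prob and rescales the survivors by 1 / keep_prob, so the expected
# activation is unchanged; in eval mode (training=False) it is a no-op.
if __name__ == "__main__":
    _x = torch.ones(4, 3, 8)
    _out = drop_path(_x, drop_prob=0.5, training=True)
    print(_out[:, 0, 0])  # each entry is either 0.0 or 2.0

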
class BasicBlock(nn.Module):
    """Standard ResNet basic block: two 3x3 convolutions with an identity shortcut."""

    __constants__ = ["downsample"]

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """

    def __init__(
        self, img_size=14, patch_size=16, in_c=256, embed_dim=768, norm_layer=None
    ):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        # 1x1 conv projects the CNN feature channels to the transformer embedding dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=1)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, embed_dim, H, W) -> (B, H*W, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x


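# Shape sketch (illustrative, values assumed): a 14x14 feature map with 256 channels
# becomes 196 tokens of dimension 768; the 1x1 projection ignores patch_size, so
# every spatial position yields one token.
if __name__ == "__main__":
    _pe = PatchEmbed(img_size=14, patch_size=14, in_c=256, embed_dim=768)
    _feats = torch.randn(2, 256, 14, 14)
    print(_pe(_feats).shape)  # expected: torch.Size([2, 196, 768])

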
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        in_chans,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop_ratio=0.0,
        proj_drop_ratio=0.0,
    ):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.img_chanel = in_chans + 1  # image tokens plus the class token
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # attend over the image tokens (plus the class token) only
        x_img = x[:, : self.img_chanel, :]
        B, N, C = x_img.shape

        # (B, N, 3C) -> (3, B, num_heads, N, head_dim)
        qkv = (
            self.qkv(x_img)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x_img = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x_img = self.proj(x_img)
        x_img = self.proj_drop(x_img)
        return x_img


class AttentionBlock(nn.Module):
    """BasicBlock variant with an ECA channel-attention gate after the second convolution."""

    __constants__ = ["downsample"]

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(AttentionBlock, self).__init__()
        norm_layer = nn.BatchNorm2d
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

        self.inplanes = inplanes
        self.eca_block = eca_block()

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.eca_block(out)
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    def __init__(
        self,
        dim,
        in_chans,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_ratio=0.0,
        attn_drop_ratio=0.0,
        drop_path_ratio=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.img_chanel = in_chans + 1

        self.conv = nn.Conv1d(self.img_chanel, self.img_chanel, 1)
        self.attn = Attention(
            dim,
            in_chans=in_chans,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_ratio=attn_drop_ratio,
            proj_drop_ratio=drop_ratio,
        )

        self.drop_path = (
            DropPath(drop_path_ratio) if drop_path_ratio > 0.0 else nn.Identity()
        )
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop_ratio,
        )

    def forward(self, x):
        # pre-norm transformer block: attention and MLP, each on a residual path
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class ClassificationHead(nn.Module):
    def __init__(self, input_dim: int, target_dim: int):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, target_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        y_hat = self.linear(x)
        return y_hat


def load_pretrained_weights(model, checkpoint):
    import collections

    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    model_dict = model.state_dict()
    new_state_dict = collections.OrderedDict()
    matched_layers, discarded_layers = [], []
    for k, v in state_dict.items():
        # strip the "module." prefix left behind by nn.DataParallel checkpoints
        if k.startswith("module."):
            k = k[7:]
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)
    print(
        "load_pretrained_weights: matched {} layers, discarded {}".format(
            len(matched_layers), len(discarded_layers)
        )
    )
    return model


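# Minimal usage sketch (hypothetical modules, not a real checkpoint): only entries
# whose name and shape match the target model's state_dict are copied; everything
# else is silently discarded, which makes partial/backbone-only loading possible.
if __name__ == "__main__":
    _src = ClassificationHead(input_dim=768, target_dim=7)
    _dst = ClassificationHead(input_dim=768, target_dim=7)
    load_pretrained_weights(_dst, {"state_dict": _src.state_dict()})

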
class eca_block(nn.Module):
    """Efficient Channel Attention: a 1D convolution over the globally pooled channel
    descriptor produces a per-channel sigmoid gate that rescales the input."""

    def __init__(self, channel=196, b=1, gamma=2):
        super(eca_block, self).__init__()
        # adaptive kernel size derived from the channel count, forced to be odd
        kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)


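# Illustrative check (shapes assumed): the ECA gate preserves the input shape and
# only rescales channels, so it can be dropped into any (B, C, H, W) feature map.
if __name__ == "__main__":
    _eca = eca_block(channel=196)
    _fmap = torch.randn(2, 196, 14, 14)
    print(_eca(_fmap).shape)  # expected: torch.Size([2, 196, 14, 14])

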
class SE_block(nn.Module):
    """Squeeze-and-Excitation style gate applied to a flat feature vector."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_dim, input_dim)
        self.relu = nn.ReLU()
        self.linear2 = torch.nn.Linear(input_dim, input_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x1 = self.linear1(x)
        x1 = self.relu(x1)
        x1 = self.linear2(x1)
        x1 = self.sigmoid(x1)
        x = x * x1
        return x


class VisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=14,
        patch_size=14,
        in_c=147,
        num_classes=7,
        embed_dim=768,
        depth=6,
        num_heads=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        representation_size=None,
        distilled=False,
        drop_ratio=0.0,
        attn_drop_ratio=0.0,
        drop_path_ratio=0.0,
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        # num_features kept for consistency with other timm-style models
        self.num_features = self.embed_dim = embed_dim
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        # learnable class token and position embedding for in_c tokens + 1 class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, in_c + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        self.se_block = SE_block(input_dim=embed_dim)

        # NOTE: the patch embedding is built for 256-channel CNN feature maps and is
        # not called in forward_features(); the model consumes token sequences directly.
        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_c=256, embed_dim=768
        )
        num_patches = self.patch_embed.num_patches
        self.head = ClassificationHead(input_dim=embed_dim, target_dim=self.num_classes)
        self.dist_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        )
        self.eca_block = eca_block()

        self.CON1 = nn.Conv2d(256, 768, kernel_size=1, stride=1, bias=False)
        self.IRLinear1 = nn.Linear(1024, 768)
        self.IRLinear2 = nn.Linear(768, 512)
        # stochastic depth decay rule: drop-path rate increases linearly with depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]
        self.blocks = nn.Sequential(
            *[
                Block(
                    dim=embed_dim,
                    in_chans=in_c,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop_ratio=drop_ratio,
                    attn_drop_ratio=attn_drop_ratio,
                    drop_path_ratio=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # optional pre-logits representation layer (only used without distillation)
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(
                OrderedDict(
                    [
                        ("fc", nn.Linear(embed_dim, representation_size)),
                        ("act", nn.Tanh()),
                    ]
                )
            )
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # distillation head as in DeiT (only built when distilled=True)
        self.head_dist = None
        if distilled:
            self.head_dist = (
                nn.Linear(self.embed_dim, self.num_classes)
                if num_classes > 0
                else nn.Identity()
            )

        # weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # x is a token sequence of shape (B, in_c, embed_dim); prepend the class
        # (and optional distillation) token, add the position embedding, then run
        # the transformer blocks.
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat(
                (cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1
            )

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.se_block(x)
        x1 = self.head(x)
        return x1


def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


def vit_base_patch16_224(num_classes: int = 7):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  password: eu9f
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        representation_size=None,
        num_classes=num_classes,
    )

    return model


def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        representation_size=768 if has_logits else None,
        num_classes=num_classes,
    )
    return model


def vit_base_patch32_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  password: s5hl
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=32,
        embed_dim=768,
        depth=12,
        num_heads=12,
        representation_size=None,
        num_classes=num_classes,
    )
    return model


def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=32,
        embed_dim=768,
        depth=12,
        num_heads=12,
        representation_size=768 if has_logits else None,
        num_classes=num_classes,
    )
    return model


def vit_large_patch16_224(num_classes: int = 1000):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  password: qqt8
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        representation_size=None,
        num_classes=num_classes,
    )
    return model


def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        representation_size=1024 if has_logits else None,
        num_classes=num_classes,
    )
    return model


def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=32,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        representation_size=1024 if has_logits else None,
        num_classes=num_classes,
    )
    return model


def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: converted weights not currently available, too large for github release hosting.
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        representation_size=1280 if has_logits else None,
        num_classes=num_classes,
    )
    return model


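# Hedged smoke test (shapes assumed from the defaults above): forward_features()
# never calls self.patch_embed, so the model consumes a pre-extracted token sequence
# of shape (batch, in_c, embed_dim) rather than a raw image. With the defaults
# (in_c=147, embed_dim=768, num_classes=7) that means a (B, 147, 768) input.
if __name__ == "__main__":
    _model = VisionTransformer()
    _tokens = torch.randn(2, 147, 768)
    print(_model(_tokens).shape)  # expected: torch.Size([2, 7])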