import math
from enum import Enum

import einops
import torch
import torch.nn.functional as F
from torch import nn

from rscd.models.decoderheads.vision_lstm_util import interpolate_sincos, to_ntuple, VitPatchEmbed, VitPosEmbed2d, DropPath
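

# Vision-LSTM (ViL) building blocks: a parallel, stabilized mLSTM cell, headwise linear
# projections, a causal depthwise convolution, and the ViL layer/block/backbone that
# traverses the patch sequence in alternating directions.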


class SequenceTraversal(Enum):
    ROWWISE_FROM_TOP_LEFT = "rowwise_from_top_left"
    ROWWISE_FROM_BOT_RIGHT = "rowwise_from_bot_right"


def bias_linspace_init_(param: torch.Tensor, start: float = 3.4, end: float = 6.0) -> torch.Tensor:
    """Linearly spaced bias init across dimensions."""
    assert param.dim() == 1, f"param must be 1-dimensional (typically a bias), got {param.dim()}"
    n_dims = param.shape[0]
    init_vals = torch.linspace(start, end, n_dims)
    with torch.no_grad():
        param.copy_(init_vals)
    return param


def small_init_(param: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution.
    Adopted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py.
    """
    std = math.sqrt(2 / (5 * dim))
    torch.nn.init.normal_(param, mean=0.0, std=std)
    return param


def wang_init_(param: torch.Tensor, dim: int, num_blocks: int):
    """Adopted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py."""
    std = 2 / num_blocks / math.sqrt(dim)
    torch.nn.init.normal_(param, mean=0.0, std=std)
    return param


def parallel_stabilized_simple(
    queries: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    igate_preact: torch.Tensor,
    fgate_preact: torch.Tensor,
    lower_triangular_matrix: torch.Tensor = None,
    stabilize_rowwise: bool = True,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    This is the mLSTM cell in parallel form.
    This version is stabilized. We control the range of exp() arguments by
    ensuring that they are always smaller than 0.0 by subtracting the maximum.

    Args:
        :param queries: (torch.Tensor) (B, NH, S, DH)
        :param keys: (torch.Tensor) (B, NH, S, DH)
        :param values: (torch.Tensor) (B, NH, S, DH)
        :param igate_preact: (torch.Tensor) (B, NH, S, 1)
        :param fgate_preact: (torch.Tensor) (B, NH, S, 1)
        :param lower_triangular_matrix: (torch.Tensor) (S, S). Defaults to None.
        :param stabilize_rowwise: (bool) Whether to stabilize the combination matrix C rowwise (take maximum per row).
            Alternative: subtract the maximum over all rows. Defaults to True.
        :param eps: (float) small constant to avoid division by 0. Defaults to 1e-6.

    Returns:
        torch.Tensor: (B, NH, S, DH), h_tilde_state
    """
    B, NH, S, DH = queries.shape
    _dtype, _device = queries.dtype, queries.device

    # forget gate pre-activations -> log-space gates
    log_fgates = torch.nn.functional.logsigmoid(fgate_preact)
    if lower_triangular_matrix is None or S < lower_triangular_matrix.size(-1):
        ltr = torch.tril(torch.ones((S, S), dtype=torch.bool, device=_device))
    else:
        ltr = lower_triangular_matrix
    assert ltr.dtype == torch.bool, f"lower_triangular_matrix must be of dtype bool, got {ltr.dtype}"

    # cumulative sum of the log forget gates (prepend a zero for the initial state)
    log_fgates_cumsum = torch.cat(
        [
            torch.zeros((B, NH, 1, 1), dtype=_dtype, device=_device),
            torch.cumsum(log_fgates, dim=-2),
        ],
        dim=-2,
    )

    # pairwise differences of the cumulative sums give the log forget-gate matrix
    rep_log_fgates_cumsum = log_fgates_cumsum.repeat(1, 1, 1, S + 1)
    _log_fg_matrix = rep_log_fgates_cumsum - rep_log_fgates_cumsum.transpose(-2, -1)

    # mask out the upper triangle (future positions) with -inf
    log_fg_matrix = torch.where(ltr, _log_fg_matrix[:, :, 1:, 1:], -float("inf"))

    # gate decay matrix D combines forget and input gates
    log_D_matrix = log_fg_matrix + igate_preact.transpose(-2, -1)

    # stabilize by subtracting the maximum before exponentiation
    if stabilize_rowwise:
        max_log_D, _ = torch.max(log_D_matrix, dim=-1, keepdim=True)
    else:
        max_log_D = torch.max(log_D_matrix.view(B, NH, -1), dim=-1, keepdim=True)[0].unsqueeze(-1)
    log_D_matrix_stabilized = log_D_matrix - max_log_D
    D_matrix = torch.exp(log_D_matrix_stabilized)

    keys_scaled = keys / math.sqrt(DH)

    # gated attention-like combination matrix and its stabilized normalizer
    qk_matrix = queries @ keys_scaled.transpose(-2, -1)
    C_matrix = qk_matrix * D_matrix
    normalizer = torch.maximum(C_matrix.sum(dim=-1, keepdim=True).abs(), torch.exp(-max_log_D))
    C_matrix_normalized = C_matrix / (normalizer + eps)

    # retrieve values
    h_tilde_state = C_matrix_normalized @ values
    return h_tilde_state
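

# Shape sketch for parallel_stabilized_simple (illustrative values only):
#   B, NH, S, DH = 2, 4, 16, 8
#   q, k, v = (torch.randn(B, NH, S, DH) for _ in range(3))
#   i_pre, f_pre = torch.randn(B, NH, S, 1), torch.randn(B, NH, S, 1)
#   parallel_stabilized_simple(q, k, v, i_pre, f_pre).shape  # -> (B, NH, S, DH)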


class LinearHeadwiseExpand(nn.Module):
    """
    Structured headwise projection layer: the input is split into `num_heads` chunks and each
    chunk is projected by its own weight matrix. In this variant the per-head output dimension
    equals the per-head input dimension, so the overall feature dimension is preserved.
    """

    def __init__(self, dim, num_heads, bias=False):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads

        dim_per_head = dim // num_heads
        self.weight = nn.Parameter(torch.empty(num_heads, dim_per_head, dim_per_head))
        if bias:
            self.bias = nn.Parameter(torch.empty(dim))
        else:
            self.bias = None
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight.data, mean=0.0, std=math.sqrt(2 / 5 / self.weight.shape[-1]))
        if self.bias is not None:
            nn.init.zeros_(self.bias.data)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = einops.rearrange(x, "... (nh d) -> ... nh d", nh=self.num_heads)
        x = einops.einsum(
            x,
            self.weight,
            "... nh d, nh out_d d -> ... nh out_d",
        )
        x = einops.rearrange(x, "... nh out_d -> ... (nh out_d)")
        if self.bias is not None:
            x = x + self.bias
        return x

    def extra_repr(self):
        return (
            f"dim={self.dim}, "
            f"num_heads={self.num_heads}, "
            f"bias={self.bias is not None}, "
        )
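

# LinearHeadwiseExpand maps (..., dim) -> (..., dim), with each of the `num_heads` chunks
# of size dim // num_heads projected independently by its own weight matrix.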


class CausalConv1d(nn.Module):
    """
    Implements causal depthwise convolution of a time series tensor.
    Input: tensor of shape (B, T, F), i.e. (batch, time, feature)
    Output: tensor of shape (B, T, F)

    Args:
        dim: number of features in the input tensor
        kernel_size: size of the kernel for the depthwise convolution
        bias: whether to use bias in the depthwise convolution

    The convolution uses groups=dim, so every feature channel is convolved independently
    (no channel mixing).
    """

    def __init__(self, dim, kernel_size=4, bias=True):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.bias = bias

        # pad on both sides, then crop the right side to make the convolution causal
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel_size,
            padding=self.pad,
            groups=dim,
            bias=bias,
        )
        self.reset_parameters()

    def reset_parameters(self):
        self.conv.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = einops.rearrange(x, "b l d -> b d l")
        x = self.conv(x)
        # drop the trailing positions introduced by the symmetric padding
        if self.pad > 0:
            x = x[:, :, :-self.pad]
        x = einops.rearrange(x, "b d l -> b l d")
        return x


class LayerNorm(nn.Module):
    """LayerNorm with optional weight/bias and an optional residual-weight parameterization (weight = 1 + learnable delta)."""

    def __init__(
        self,
        ndim: int = -1,
        weight: bool = True,
        bias: bool = False,
        eps: float = 1e-5,
        residual_weight: bool = True,
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(ndim)) if weight else None
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
        self.eps = eps
        self.residual_weight = residual_weight
        self.ndim = ndim
        self.reset_parameters()

    @property
    def weight_proxy(self) -> torch.Tensor:
        if self.weight is None:
            return None
        if self.residual_weight:
            return 1.0 + self.weight
        else:
            return self.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            x,
            normalized_shape=(self.ndim,),
            weight=self.weight_proxy,
            bias=self.bias,
            eps=self.eps,
        )

    def reset_parameters(self):
        if self.weight_proxy is not None:
            if self.residual_weight:
                nn.init.zeros_(self.weight)
            else:
                nn.init.ones_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)


class MultiHeadLayerNorm(LayerNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 4, "Input must be 4D tensor (B, NH, S, DH)"
        B, NH, S, DH = x.shape

        # (B, NH, S, DH) -> (B, S, NH, DH) -> (B * S, NH * DH)
        gn_in_1 = x.transpose(1, 2)
        gn_in_2 = gn_in_1.reshape(B * S, NH * DH)
        # group norm with NH groups normalizes each head separately
        out = F.group_norm(
            gn_in_2,
            num_groups=NH,
            weight=self.weight_proxy,
            bias=self.bias,
            eps=self.eps,
        )

        out = out.view(B, S, NH, DH).transpose(1, 2)
        return out


class MatrixLSTMCell(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads

        self.igate = nn.Linear(3 * dim, num_heads)
        self.fgate = nn.Linear(3 * dim, num_heads)
        self.outnorm = MultiHeadLayerNorm(ndim=dim, weight=True, bias=False)
        self.causal_mask_cache = {}
        self.reset_parameters()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        B, S, _ = q.shape

        if_gate_input = torch.cat([q, k, v], dim=-1)
        q = q.view(B, S, self.num_heads, -1)
        k = k.view(B, S, self.num_heads, -1)
        v = v.view(B, S, self.num_heads, -1)

        # (B, S, NH, DH) -> (B, NH, S, DH)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # gate pre-activations: (B, S, NH) -> (B, NH, S, 1)
        igate_preact = self.igate(if_gate_input)
        igate_preact = igate_preact.transpose(-1, -2).unsqueeze(-1)
        fgate_preact = self.fgate(if_gate_input)
        fgate_preact = fgate_preact.transpose(-1, -2).unsqueeze(-1)

        # cache the causal mask per (sequence length, device) to avoid reallocating it every call
        cache_key = (S, str(q.device))
        if cache_key in self.causal_mask_cache:
            causal_mask = self.causal_mask_cache[cache_key]
        else:
            causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device=q.device))
            self.causal_mask_cache[cache_key] = causal_mask

        h_state = parallel_stabilized_simple(
            queries=q,
            keys=k,
            values=v,
            igate_preact=igate_preact,
            fgate_preact=fgate_preact,
            lower_triangular_matrix=causal_mask,
        )

        h_state_norm = self.outnorm(h_state)
        h_state_norm = h_state_norm.transpose(1, 2).reshape(B, S, -1)

        return h_state_norm

    def reset_parameters(self):
        self.outnorm.reset_parameters()

        # forget gate: zero weights, linearly spaced positive bias (gates start nearly open)
        torch.nn.init.zeros_(self.fgate.weight)
        bias_linspace_init_(self.fgate.bias, start=3.0, end=6.0)

        # input gate: zero weights, small random bias
        torch.nn.init.zeros_(self.igate.weight)
        torch.nn.init.normal_(self.igate.bias, mean=0.0, std=0.1)
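

# MatrixLSTMCell consumes flat (B, S, dim) q/k/v tensors and internally reshapes them into
# (B, num_heads, S, dim // num_heads) before calling parallel_stabilized_simple.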


class ViLLayer(nn.Module):
    def __init__(
        self,
        dim,
        direction,
        expansion=2,
        qkv_block_size=4,
        proj_bias=False,
        conv_bias=True,
        kernel_size=4,
    ):
        super().__init__()
        # fall back to a smaller block size if dim is not divisible by the default
        if dim % qkv_block_size != 0:
            qkv_block_size = 2

        self.dim = dim
        self.direction = direction
        self.expansion = expansion
        self.qkv_block_size = qkv_block_size
        self.proj_bias = proj_bias
        self.conv_bias = conv_bias
        self.kernel_size = kernel_size

        inner_dim = expansion * dim
        num_heads = inner_dim // qkv_block_size
        self.proj_up = nn.Linear(
            in_features=dim,
            out_features=2 * inner_dim,
            bias=proj_bias,
        )
        self.q_proj = LinearHeadwiseExpand(
            dim=inner_dim,
            num_heads=num_heads,
            bias=proj_bias,
        )
        self.k_proj = LinearHeadwiseExpand(
            dim=inner_dim,
            num_heads=num_heads,
            bias=proj_bias,
        )
        self.v_proj = LinearHeadwiseExpand(
            dim=inner_dim,
            num_heads=num_heads,
            bias=proj_bias,
        )

        self.conv1d = CausalConv1d(
            dim=inner_dim,
            kernel_size=kernel_size,
            bias=conv_bias,
        )
        self.mlstm_cell = MatrixLSTMCell(
            dim=inner_dim,
            num_heads=qkv_block_size,
        )
        self.learnable_skip = nn.Parameter(torch.ones(inner_dim))

        self.proj_down = nn.Linear(
            in_features=inner_dim,
            out_features=dim,
            bias=proj_bias,
        )
        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, _ = x.shape

        # reverse the sequence for the backward traversal direction
        if self.direction == SequenceTraversal.ROWWISE_FROM_TOP_LEFT:
            pass
        elif self.direction == SequenceTraversal.ROWWISE_FROM_BOT_RIGHT:
            x = x.flip(dims=[1])
        else:
            raise NotImplementedError

        # up-projection, split into mLSTM branch and gating branch
        x_inner = self.proj_up(x)
        x_mlstm, z = torch.chunk(x_inner, chunks=2, dim=-1)

        # mLSTM branch: causal conv -> SiLU -> q/k/v projections -> mLSTM cell
        x_mlstm_conv = self.conv1d(x_mlstm)
        x_mlstm_conv_act = F.silu(x_mlstm_conv)
        q = self.q_proj(x_mlstm_conv_act)
        k = self.k_proj(x_mlstm_conv_act)
        v = self.v_proj(x_mlstm)
        h_tilde_state = self.mlstm_cell(q=q, k=k, v=v)
        h_tilde_state_skip = h_tilde_state + (self.learnable_skip * x_mlstm_conv_act)

        # output gating and down-projection
        h_state = h_tilde_state_skip * F.silu(z)
        x = self.proj_down(h_state)

        # undo the reversal so the output is in the original order
        if self.direction == SequenceTraversal.ROWWISE_FROM_TOP_LEFT:
            pass
        elif self.direction == SequenceTraversal.ROWWISE_FROM_BOT_RIGHT:
            x = x.flip(dims=[1])
        else:
            raise NotImplementedError

        return x

    def reset_parameters(self):
        # init in-projection
        small_init_(self.proj_up.weight, dim=self.dim)
        if self.proj_up.bias is not None:
            nn.init.zeros_(self.proj_up.bias)

        # init out-projection
        wang_init_(self.proj_down.weight, dim=self.dim, num_blocks=1)
        if self.proj_down.bias is not None:
            nn.init.zeros_(self.proj_down.bias)

        nn.init.ones_(self.learnable_skip)

        def _init_qkv_proj(qkv_proj: LinearHeadwiseExpand):
            small_init_(qkv_proj.weight, dim=self.dim)
            if qkv_proj.bias is not None:
                nn.init.zeros_(qkv_proj.bias)

        _init_qkv_proj(self.q_proj)
        _init_qkv_proj(self.k_proj)
        _init_qkv_proj(self.v_proj)

        self.mlstm_cell.reset_parameters()
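

# A ViLLayer maps (B, S, dim) -> (B, S, dim); e.g. ViLLayer(dim=192,
# direction=SequenceTraversal.ROWWISE_FROM_TOP_LEFT) applied to a (2, 196, 192)
# tensor returns a (2, 196, 192) tensor (batch and sequence sizes illustrative).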


class ViLBlock(nn.Module):
    def __init__(self, dim, direction, drop_path=0.0, norm_bias=False):
        super().__init__()
        self.dim = dim
        self.direction = direction
        self.norm_bias = norm_bias

        # pre-norm residual block with stochastic depth
        self.drop_path = DropPath(drop_prob=drop_path)
        self.norm = LayerNorm(ndim=dim, weight=True, bias=norm_bias)
        self.layer = ViLLayer(dim=dim, direction=direction)

        self.reset_parameters()

    def _forward_path(self, x):
        x = self.norm(x)
        x = self.layer(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.drop_path(x, self._forward_path)
        return x

    def reset_parameters(self):
        self.layer.reset_parameters()
        self.norm.reset_parameters()


class VisionLSTM(nn.Module):
    def __init__(
        self,
        dim=192,
        input_shape=(3, 224, 224),
        patch_size=16,
        depth=24,
        output_shape=(1000,),
        mode="classifier",
        pooling="bilateral_avg",
        drop_path_rate=0.0,
        stride=None,
        alternation="bidirectional",
        drop_path_decay=False,
        legacy_norm=False,
    ):
        super().__init__()
        self.input_shape = input_shape
        self.output_shape = output_shape
        ndim = len(self.input_shape) - 1
        self.patch_size = to_ntuple(patch_size, n=ndim)
        self.dim = dim
        self.depth = depth
        self.stride = stride
        self.mode = mode
        self.pooling = pooling
        self.alternation = alternation
        self.drop_path_rate = drop_path_rate
        self.drop_path_decay = drop_path_decay

        # patch embedding
        self.patch_embed = VitPatchEmbed(
            dim=dim,
            stride=stride,
            num_channels=self.input_shape[0],
            resolution=self.input_shape[1:],
            patch_size=self.patch_size,
        )

        # learnable 2d positional embedding
        self.pos_embed = VitPosEmbed2d(seqlens=self.patch_embed.seqlens, dim=dim)

        # stochastic depth: either linearly increasing per block or constant
        if drop_path_decay and drop_path_rate > 0.0:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        else:
            dpr = [drop_path_rate] * depth

        # alternate the traversal direction between consecutive blocks
        directions = []
        if alternation == "bidirectional":
            for i in range(depth):
                if i % 2 == 0:
                    directions.append(SequenceTraversal.ROWWISE_FROM_TOP_LEFT)
                else:
                    directions.append(SequenceTraversal.ROWWISE_FROM_BOT_RIGHT)
        else:
            raise NotImplementedError(f"invalid alternation '{alternation}'")

        self.blocks = nn.ModuleList(
            [
                ViLBlock(
                    dim=dim,
                    drop_path=dpr[i],
                    direction=directions[i],
                )
                for i in range(depth)
            ]
        )

        if legacy_norm:
            self.legacy_norm = LayerNorm(dim, bias=False)
        else:
            self.legacy_norm = nn.Identity()
        self.norm = nn.LayerNorm(dim, eps=1e-6)

        # head
        if mode is None:
            # no head -> return the full patch sequence as features
            assert self.output_shape is None
            assert self.pooling is None
            self.head = None
            self.output_shape = (self.patch_embed.num_patches, dim)
        elif mode == "classifier":
            assert self.output_shape is not None and len(self.output_shape) == 1, \
                "define number of classes via output_shape=(num_classes,) (e.g. output_shape=(1000,) for ImageNet-1K)"
            self.head = nn.Linear(dim, self.output_shape[0])
            # initialize the classification head with a small std
            nn.init.trunc_normal_(self.head.weight, std=2e-5)
            nn.init.zeros_(self.head.bias)
        else:
            raise NotImplementedError

    def load_state_dict(self, state_dict, strict=True):
        # interpolate the positional embedding if the checkpoint resolution differs
        old_pos_embed = state_dict["pos_embed.embed"]
        if old_pos_embed.shape != self.pos_embed.embed.shape:
            state_dict["pos_embed.embed"] = interpolate_sincos(embed=old_pos_embed, seqlens=self.pos_embed.seqlens)
        return super().load_state_dict(state_dict=state_dict, strict=strict)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed.embed"}

    def forward(self, x):
        # embed patches
        x = self.patch_embed(x)
        # add positional embedding
        x = self.pos_embed(x)

        # flatten the 2d patch grid to a 1d sequence
        x = einops.rearrange(x, "b ... d -> b (...) d")

        # apply the ViL blocks
        for block in self.blocks:
            x = block(x)
        x = self.legacy_norm(x)

        # pool
        if self.pooling is None:
            x = self.norm(x)
        elif self.pooling == "bilateral_avg":
            # average the first and last token of the sequence
            x = (x[:, 0] + x[:, -1]) / 2
            x = self.norm(x)
        else:
            raise NotImplementedError(f"pooling '{self.pooling}' is not implemented")

        # classification head
        if self.head is not None:
            x = self.head(x)

        return x
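

# Minimal smoke test (illustrative sketch only; assumes the helpers imported from
# vision_lstm_util behave as they are used above).
if __name__ == "__main__":
    model = VisionLSTM(dim=192, input_shape=(3, 224, 224), patch_size=16, depth=2)
    dummy = torch.randn(2, 3, 224, 224)
    logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 1000])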