| | |
| | from functools import partial |
| | from typing import Tuple, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from timm.models.helpers import checkpoint_seq |
| | from timm.models.vision_transformer import Block, Mlp, VisionTransformer |
| |
|
| | from masking import transformer_random_masking |
| | from vit import channel_agnostic_vit |
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
class SelfStandardize(nn.Module):
    """Divide the input by 255 and instance-normalize the result.

    The normalization layer is an ``nn.LazyInstanceNorm2d`` created without
    affine parameters and without running statistics, so this module holds
    no learnable state.
    """

    def __init__(self) -> None:
        super().__init__()
        # Lazy variant: no channel count is supplied at construction time.
        norm_options = {"affine": False, "track_running_stats": False}
        self.self_standardize = nn.LazyInstanceNorm2d(**norm_options)

    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        """Return the instance-normalized version of ``pixels / 255``.

        Args:
            pixels: Image tensor; it is cast to float32 before scaling.
                Presumably holds raw intensities in [0, 255] -- confirm
                against the caller.

        Returns:
            The output of the instance-norm layer applied to the scaled input.
        """
        scaled = pixels.to(torch.float32).div(255.0)
        return self.self_standardize(scaled)
| |
|
| |
|
class MAEEncoder(nn.Module):
    """Masked-autoencoder encoder built around a ``VisionTransformer`` backbone.

    When ``channel_agnostic`` is true the backbone is first passed through
    ``channel_agnostic_vit`` (with ``max_in_chans``); otherwise it is used
    unchanged.
    """

    def __init__(
        self,
        vit_backbone: VisionTransformer,
        max_in_chans: int = 6,
        channel_agnostic: bool = False,
    ) -> None:
        """Store the (optionally converted) backbone and the channel settings.

        Args:
            vit_backbone: Backbone providing ``patch_embed``, ``_pos_embed``,
                ``norm_pre``, ``blocks``, ``norm``, ``forward_features`` and
                ``forward_head``.
            max_in_chans: Forwarded to ``channel_agnostic_vit`` when
                ``channel_agnostic`` is set; always stored on the module.
            channel_agnostic: Whether to convert the backbone with
                ``channel_agnostic_vit``.
        """
        super().__init__()
        self.vit_backbone = (
            channel_agnostic_vit(vit_backbone, max_in_chans=max_in_chans)
            if channel_agnostic
            else vit_backbone
        )
        self.max_in_chans = max_in_chans
        self.channel_agnostic = channel_agnostic

    @property
    def embed_dim(self) -> int:
        """Embedding width reported by the backbone, as a plain ``int``."""
        return int(self.vit_backbone.embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone's feature extractor followed by its head."""
        features = self.vit_backbone.forward_features(x)
        return self.vit_backbone.forward_head(features)

    def forward_masked(
        self,
        x: torch.Tensor,
        mask_ratio: float,
        constant_noise: Union[torch.Tensor, None] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``x`` after randomly masking its non-leading tokens.

        The first token along dim 1 (presumably the class token produced by
        ``_pos_embed`` -- confirm for the backbone in use) bypasses masking
        and is re-attached in front of the surviving tokens.

        Args:
            x: Input passed to the backbone's ``patch_embed``.
            mask_ratio: Forwarded to ``transformer_random_masking``.
            constant_noise: Optional tensor forwarded to
                ``transformer_random_masking``.

        Returns:
            ``(encoded, mask, ind_restore)`` where ``mask`` and
            ``ind_restore`` are exactly what ``transformer_random_masking``
            returned.
        """
        vit = self.vit_backbone

        tokens = vit._pos_embed(vit.patch_embed(x))
        leading, patches = tokens[:, :1, :], tokens[:, 1:, :]
        patches, mask, ind_restore = transformer_random_masking(
            patches, mask_ratio, constant_noise
        )
        tokens = vit.norm_pre(torch.cat([leading, patches], dim=1))

        # Same gating as the backbone: checkpoint only outside TorchScript.
        if vit.grad_checkpointing and not torch.jit.is_scripting():
            tokens = checkpoint_seq(vit.blocks, tokens)
        else:
            tokens = vit.blocks(tokens)

        return vit.norm(tokens), mask, ind_restore
| |
|
| |
|
class MAEDecoder(nn.Module):
    """Transformer decoder for a masked autoencoder.

    ``pos_embeddings`` is initialised to ``None`` and is not assigned anywhere
    in this class; it must be set from outside before either forward method
    is called.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        depth: int = 8,
        num_heads: int = 16,
        mlp_ratio: float = 4,
        qkv_bias: bool = True,
        norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
    ) -> None:
        """Build ``depth`` transformer blocks plus a final norm layer.

        Args:
            embed_dim: Token width used by the blocks, mask token and norm.
            depth: Number of ``Block`` layers stacked in sequence.
            num_heads: Passed to every ``Block``.
            mlp_ratio: Passed to every ``Block``.
            qkv_bias: Passed to every ``Block``.
            norm_layer: Factory called as ``norm_layer(embed_dim)``; also
                handed to each ``Block``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.pos_embeddings = None  # expected to be filled in externally
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        layers = [
            Block(
                embed_dim,
                num_heads,
                mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
            )
            for _ in range(depth)
        ]
        self.blocks = nn.Sequential(*layers)
        self.norm = norm_layer(embed_dim)

    def _embed_and_decode(self, tokens: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings, then apply the blocks and final norm."""
        tokens = tokens + self.pos_embeddings
        return self.norm(self.blocks(tokens))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a full (unmasked) token sequence."""
        return self._embed_and_decode(x)

    def forward_masked(
        self, x: torch.Tensor, ind_restore: torch.Tensor
    ) -> torch.Tensor:
        """Re-insert mask tokens at the dropped positions, then decode.

        Args:
            x: Tensor indexed as ``(batch, 1 + kept, dim)``; the first token
                along dim 1 is carried through without reordering.
            ind_restore: Index tensor of shape ``(batch, total)`` used with
                ``torch.gather`` to put kept and mask tokens back in order.

        Returns:
            Decoded tensor with ``1 + ind_restore.shape[1]`` tokens per sample.
        """
        batch, seq_len, width = x.shape
        # One mask token per position missing from x (the +1 skips the
        # leading token, which is not part of ind_restore).
        n_missing = ind_restore.shape[1] + 1 - seq_len
        fillers = self.mask_token.repeat(batch, n_missing, 1)

        leading, kept = x[:, :1, :], x[:, 1:, :]
        padded = torch.cat([kept, fillers], dim=1)
        gather_index = ind_restore.unsqueeze(-1).repeat(1, 1, width)
        restored = torch.gather(padded, dim=1, index=gather_index)

        return self._embed_and_decode(torch.cat([leading, restored], dim=1))
| |
|
| |
|
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from ``x``, keys/values from ``context``.

    Args:
        embed_dim: Width of both the query and context tokens; must be a
            positive multiple of ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query and key/value projections have a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(
        self, embed_dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
    ):
        super().__init__()
        # Fail here with a clear message instead of as an opaque reshape
        # error inside forward().
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = embed_dim // num_heads
        self.scale = head_dim**-0.5

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        # Keys and values share one projection; they are split in forward().
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, context):
        """Attend from every token of ``x`` to every token of ``context``.

        Args:
            x: Query tokens of shape ``(B, N, C)``.
            context: Key/value tokens of shape ``(B, M, C)``.

        Returns:
            Tensor of shape ``(B, N, C)``. No residual connection is applied
            here.
        """
        B, N, C = x.shape
        _, M, _ = context.shape
        head_dim = C // self.num_heads

        # (B, N, C) -> (B, heads, N, head_dim)
        q = self.q(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        # (B, M, 2C) -> (2, B, heads, M, head_dim)
        kv = (
            self.kv(context)
            .reshape(B, M, 2, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
| |
|
| |
|
class CAMAEDecoder(nn.Module):
    """Per-modality decoder with a shared cross-attention stage.

    The input sequence is one leading token followed by ``num_modalities``
    consecutive groups of ``tokens_per_modality`` tokens. Each group
    cross-attends to all non-leading tokens, passes through a shared MLP with
    a residual connection, and is then decoded by its own stack of
    transformer blocks.

    ``pos_embeddings`` is initialised to ``None`` and is not assigned anywhere
    in this class; it must be set from outside before ``forward`` is called.
    """

    def __init__(
        self,
        num_modalities: int = 6,
        tokens_per_modality: int = 256,
        embed_dim: int = 256,
        depth: int = 2,
        num_heads: int = 16,
        mlp_ratio: float = 4,
        qkv_bias: bool = True,
        norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
    ) -> None:
        """Create the shared cross-attention/MLP stage and per-modality decoders.

        Args:
            num_modalities: Number of token groups and of decoder stacks.
            tokens_per_modality: Number of tokens in each group.
            embed_dim: Token width.
            depth: Number of ``Block`` layers in each per-modality decoder.
            num_heads: Passed to every ``Block``.
            mlp_ratio: Passed to every ``Block`` and used to size the shared
                ``Mlp`` hidden layer.
            qkv_bias: Passed to every ``Block``.
            norm_layer: Factory called as ``norm_layer(embed_dim)``; also
                handed to each ``Block``.
        """
        super().__init__()
        self.num_modalities = num_modalities
        self.tokens_per_modality = tokens_per_modality
        self.embed_dim = embed_dim
        self.pos_embeddings = None  # expected to be filled in externally
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Frozen zero token that lines up with the leading position, so no
        # modality token is added to it.
        self.placeholder = nn.Parameter(
            torch.zeros(1, 1, embed_dim), requires_grad=False
        )
        # One token per modality, broadcast over that modality's positions.
        self.modality_tokens = nn.ParameterList(
            [
                nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                for _ in range(self.num_modalities)
            ]
        )

        # NOTE: built with CrossAttention's own defaults for num_heads and
        # qkv_bias; the num_heads / qkv_bias arguments above only reach the
        # Block layers.
        self.cross_attention = CrossAttention(embed_dim=self.embed_dim)
        self.mlp = Mlp(self.embed_dim, hidden_features=int(self.embed_dim * mlp_ratio))

        self.decoders = nn.ModuleList(
            [
                nn.Sequential(
                    *[
                        Block(
                            embed_dim,
                            num_heads,
                            mlp_ratio,
                            qkv_bias=qkv_bias,
                            norm_layer=norm_layer,
                        )
                        for _ in range(depth)
                    ]
                )
                for _ in range(self.num_modalities)
            ]
        )

        self.context_norm = norm_layer(embed_dim)
        self.query_norm = norm_layer(embed_dim)
        self.out_norm = norm_layer(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a full token sequence, one modality group at a time.

        Args:
            x: Tensor indexed as ``(batch, 1 + num_modalities *
                tokens_per_modality, dim)``.

        Returns:
            Tensor with the same token layout: the leading token (with
            positional and placeholder embeddings added, but not decoded)
            followed by the decoded groups in modality order.
        """
        span = self.tokens_per_modality

        modality_tokens_concat = torch.cat(
            [self.placeholder]
            + [m_t.repeat(1, span, 1) for m_t in self.modality_tokens],
            dim=1,
        )

        x = x + self.pos_embeddings + modality_tokens_concat
        x_ = x[:, 1:, :]

        # The context is identical for every modality, so normalize it once
        # instead of once per loop iteration.
        context = self.context_norm(x_)

        decoded = []
        for m, decoder in enumerate(self.decoders):
            x_m = x_[:, m * span : (m + 1) * span, :]
            # No residual around the cross-attention; only around the MLP.
            x_m = self.cross_attention(self.query_norm(x_m), context)
            x_m = x_m + self.mlp(self.out_norm(x_m))
            decoded.append(decoder(x_m))

        return torch.cat([x[:, :1, :], *decoded], dim=1)

    def forward_masked(
        self, x: torch.Tensor, ind_restore: torch.Tensor
    ) -> torch.Tensor:
        """Re-insert mask tokens at the dropped positions, then run ``forward``.

        Args:
            x: Tensor indexed as ``(batch, 1 + kept, dim)``; the first token
                along dim 1 is carried through without reordering.
            ind_restore: Index tensor of shape ``(batch, total)`` used with
                ``torch.gather`` to put kept and mask tokens back in order.

        Returns:
            The result of ``forward`` on the restored sequence.
        """
        # One mask token per position missing from x (the +1 skips the
        # leading token, which is not part of ind_restore).
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(
            x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )
        x = torch.cat([x[:, :1, :], x_], dim=1)
        return self.forward(x)
| |
|