from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.utils import cached_file

from .loss import FourierLoss
from .mae_modules import CAMAEDecoder, MAEDecoder, MAEEncoder
from .mae_utils import flatten_images
from .normalizer import Normalizer
from .vit import (
    generate_2d_sincos_pos_embeddings,
    sincos_positional_encoding_vit,
    vit_small_patch16_256,
)

TensorDict = Dict[str, torch.Tensor]


class MAEConfig(PretrainedConfig):
    """Configuration for the channel-agnostic masked autoencoder below.

    Several fields (encoder, decoder, loss, optimizer, input_norm,
    fourier_loss, lr_scheduler, crop_size) are stored for checkpoint
    compatibility only; MAEModel in this file hardcodes its architecture
    and losses.
    """

    model_type = "MAE"

    def __init__(
        self,
        mask_ratio=0.75,
        encoder=None,
        decoder=None,
        loss=None,
        optimizer=None,
        input_norm=None,
        fourier_loss=None,
        fourier_loss_weight=0.0,
        lr_scheduler=None,
        use_MAE_weight_init=False,
        crop_size=-1,
        mask_fourier_loss=True,
        return_channelwise_embeddings=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.mask_ratio = mask_ratio
        self.encoder = encoder
        self.decoder = decoder
        self.loss = loss
        self.optimizer = optimizer
        self.input_norm = input_norm
        self.fourier_loss = fourier_loss
        self.fourier_loss_weight = fourier_loss_weight
        self.lr_scheduler = lr_scheduler
        self.use_MAE_weight_init = use_MAE_weight_init
        self.crop_size = crop_size
        self.mask_fourier_loss = mask_fourier_loss
        self.return_channelwise_embeddings = return_channelwise_embeddings
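
# A minimal construction sketch (hypothetical values, not a tuned training
# setup); the config round-trips through the standard HF helpers:
#
#     config = MAEConfig(mask_ratio=0.75, fourier_loss_weight=0.01)
#     config.save_pretrained("/tmp/mae_demo")  # writes config.json
#     config = MAEConfig.from_pretrained("/tmp/mae_demo")

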
class MAEModel(PreTrainedModel):
    config_class = MAEConfig

    # Keys for the loss dict returned by compute_MAE_loss / training_step.
    TOTAL_LOSS = "loss"
    RECON_LOSS = "reconstruction_loss"
    FOURIER_LOSS = "fourier_loss"

    def __init__(self, config: MAEConfig):
        super().__init__(config)

        self.mask_ratio = config.mask_ratio

        # Hardcoded channel-agnostic architecture: a ViT-S/16 encoder with
        # fixed sin-cos positional embeddings over 256x256 crops (up to 11
        # input channels), and an 8-block transformer decoder that keeps a
        # separate token stream per modality.
        self.encoder = MAEEncoder(
            vit_backbone=sincos_positional_encoding_vit(
                vit_backbone=vit_small_patch16_256(global_pool="avg")
            ),
            max_in_chans=11,
            channel_agnostic=True,
        )
        self.decoder = CAMAEDecoder(
            depth=8,
            embed_dim=512,
            mlp_ratio=4,
            norm_layer=nn.LayerNorm,
            num_heads=16,
            num_modalities=6,
            qkv_bias=True,
            tokens_per_modality=256,
        )
        # Input pipeline: the project's Normalizer, then a parameter-free
        # per-image, per-channel instance norm. num_features=None is fine here
        # because with affine=False and track_running_stats=False the layer
        # allocates no per-channel state.
        self.input_norm = torch.nn.Sequential(
            Normalizer(),
            nn.InstanceNorm2d(None, affine=False, track_running_stats=False),
        )

        self.fourier_loss_weight = config.fourier_loss_weight
        self.mask_fourier_loss = config.mask_fourier_loss
        self.return_channelwise_embeddings = config.return_channelwise_embeddings
        self.tokens_per_channel = 256  # (256 / 16) ** 2 patches per channel

        # reduction="none" keeps per-element losses so compute_MAE_loss can
        # average over masked patches only.
        self.loss = torch.nn.MSELoss(reduction="none")

        self.fourier_loss = FourierLoss(num_multimodal_modalities=6)
        if self.fourier_loss_weight > 0 and self.fourier_loss is None:
            raise ValueError(
                "fourier_loss_weight is set but no fourier_loss was constructed"
            )
        if self.fourier_loss_weight >= 1:
            raise ValueError(
                "fourier_loss_weight is a mixing factor and must be < 1"
            )

        self.patch_size = int(self.encoder.vit_backbone.patch_embed.patch_size[0])

        # Project encoder tokens into the decoder width.
        self.encoder_decoder_proj = nn.Linear(
            self.encoder.embed_dim, self.decoder.embed_dim, bias=True
        )
        # Prediction head: one flattened patch per token. With the
        # channel-agnostic encoder hardcoded above this is patch_size**2
        # output features; the non-agnostic branch is unreachable here (note
        # that self.in_chans is never set on this class).
        self.decoder_pred = nn.Linear(
            self.decoder.embed_dim,
            self.patch_size**2
            * (1 if self.encoder.channel_agnostic else self.in_chans),
            bias=True,
        )

        # Fixed (non-learned) 2D sin-cos positional embeddings for the
        # decoder, tiled across modalities when the encoder is
        # channel-agnostic.
        self.decoder.pos_embeddings = generate_2d_sincos_pos_embeddings(
            self.decoder.embed_dim,
            length=self.encoder.vit_backbone.patch_embed.grid_size[0],
            use_class_token=self.encoder.vit_backbone.cls_token is not None,
            num_modality=(
                self.decoder.num_modalities if self.encoder.channel_agnostic else 1
            ),
        )

        if config.use_MAE_weight_init:
            # Initialization from the original MAE recipe: treat the
            # patch-embed convolution as a linear layer for xavier_uniform and
            # draw the cls/mask tokens from a small normal distribution.
            w = self.encoder.vit_backbone.patch_embed.proj.weight.data
            torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            torch.nn.init.normal_(self.encoder.vit_backbone.cls_token, std=0.02)
            torch.nn.init.normal_(self.decoder.mask_token, std=0.02)
            self.apply(self._MAE_init_weights)

    def setup(self, stage: str) -> None:
        # Lightning-style lifecycle hook kept for trainer compatibility;
        # PreTrainedModel defines no setup(), so there is nothing to delegate
        # to upstream.
        pass

    def _MAE_init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @staticmethod
    def decode_to_reconstruction(
        encoder_latent: torch.Tensor,
        ind_restore: torch.Tensor,
        proj: torch.nn.Module,
        decoder: Union[MAEDecoder, CAMAEDecoder],
        pred: torch.nn.Module,
    ) -> torch.Tensor:
        """Feed the encoder latent through the decoder's projection, transformer blocks, and prediction head."""
        decoder_latent_projection = proj(encoder_latent)  # encoder width -> decoder width
        decoder_tokens = decoder.forward_masked(
            decoder_latent_projection, ind_restore
        )  # re-insert mask tokens, unshuffle, run the decoder blocks
        predicted_reconstruction = pred(decoder_tokens)  # decoder width -> flattened patch pixels
        return predicted_reconstruction[:, 1:, :]  # drop the class token
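
    # Decoding sketch (widths assume the hardcoded ViT-S encoder and 512-dim
    # decoder above): with L patch tokens total and mask_ratio 0.75, the
    # encoder latent is roughly [B, 1 + L // 4, 384]; forward_masked re-inserts
    # mask tokens at the positions recorded in ind_restore, so the prediction
    # head sees all L positions and returns [B, L, patch_size**2].
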
    def forward(
        self, imgs: torch.Tensor, constant_noise: Union[torch.Tensor, None] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        imgs = self.input_norm(imgs)
        # Encode only the visible tokens; mask marks removed patches (1 =
        # masked) and ind_restore records the unshuffle order.
        latent, mask, ind_restore = self.encoder.forward_masked(
            imgs, self.mask_ratio, constant_noise
        )
        reconstruction = self.decode_to_reconstruction(
            latent,
            ind_restore,
            self.encoder_decoder_proj,
            self.decoder,
            self.decoder_pred,
        )
        return latent, reconstruction, mask
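
    # Forward-pass shape sketch (assumes the hardcoded setup above: ViT-S/16,
    # 256x256 crops, 16x16 patches; illustrative values for a 6-channel batch):
    #
    #     imgs = torch.randn(2, 6, 256, 256)
    #     latent, reconstruction, mask = model(imgs)
    #     # reconstruction: [2, 6 * 256, 256]  one flattened 16x16 patch per token
    #     # mask:           [2, 6 * 256]       1 = masked, 0 = visible
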
    def compute_MAE_loss(
        self,
        reconstruction: torch.Tensor,
        img: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Computes the final loss and returns component losses for metric reporting."""
        loss_dict = {}
        img = self.input_norm(img)
        target_flattened = flatten_images(
            img,
            patch_size=self.patch_size,
            channel_agnostic=self.encoder.channel_agnostic,
        )

        loss: torch.Tensor = self.loss(reconstruction, target_flattened)
        loss = loss.mean(dim=-1)  # mean over the pixels of each patch
        loss = (loss * mask).sum() / mask.sum()  # mean over masked patches only
        loss_dict[self.RECON_LOSS] = loss.item()

        # Optional Fourier-spectrum loss: averaged over every patch, or, when
        # mask_fourier_loss is set, over masked patches only (matching the
        # reconstruction loss above). The final loss is the convex mix
        # (1 - w) * reconstruction + w * fourier.
        if self.fourier_loss_weight > 0:
            floss: torch.Tensor = self.fourier_loss(reconstruction, target_flattened)
            if not self.mask_fourier_loss:
                floss = floss.mean()
            else:
                floss = floss.mean(dim=-1)
                floss = (floss * mask).sum() / mask.sum()

            loss_dict[self.FOURIER_LOSS] = floss.item()

            loss = (1 - self.fourier_loss_weight) * loss + (
                self.fourier_loss_weight * floss
            )
        return loss, loss_dict
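
    # Worked example of the mix: with fourier_loss_weight w = 0.01, a masked
    # reconstruction loss of 0.80 and a fourier loss of 2.00 combine to
    # 0.99 * 0.80 + 0.01 * 2.00 = 0.812 (values are illustrative only).
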
    def training_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
        img = batch["pixels"]
        # Clone so any in-place normalization inside forward() leaves the loss
        # target untouched.
        latent, reconstruction, mask = self(img.clone())
        full_loss, loss_dict = self.compute_MAE_loss(reconstruction, img.float(), mask)
        return {
            "loss": full_loss,
            **loss_dict,
        }

    def validation_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
        return self.training_step(batch, batch_idx)

    def update_metrics(self, outputs: TensorDict, batch: TensorDict) -> None:
        # Assumes the surrounding training framework attaches `self.metrics`
        # and `self.lr_scheduler`; neither is created in this module.
        self.metrics["lr"].update(value=self.lr_scheduler.get_last_lr())
        for key, value in outputs.items():
            if key.endswith("loss"):
                self.metrics[key].update(value)

    def on_validation_batch_end(
        self,
        outputs: TensorDict,
        batch: TensorDict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # Lightning-style hook; PreTrainedModel has no such method, so do not
        # delegate upward.
        pass

    def predict(self, imgs: torch.Tensor) -> torch.Tensor:
        imgs = self.input_norm(imgs)
        # Full (unmasked) forward pass: [N, 1 + C * tokens_per_channel, d].
        X = self.encoder.vit_backbone.forward_features(imgs)
        if self.return_channelwise_embeddings:
            N, _, d = X.shape
            num_channels = imgs.shape[1]
            # Pool each channel's tokens separately, then concatenate: [N, C * d].
            X_reshaped = X[:, 1:, :].view(N, num_channels, self.tokens_per_channel, d)
            pooled_segments = X_reshaped.mean(dim=2)
            latent = pooled_segments.view(N, num_channels * d).contiguous()
        else:
            latent = X[:, 1:, :].mean(dim=1)  # mean-pool all patch tokens: [N, d]
        return latent
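
    # Embedding sketch: for a 6-channel batch with the ViT-S backbone above
    # (d = 384), channelwise pooling yields [N, 6 * 384] = [N, 2304], while
    # plain pooling yields [N, 384].
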
    def save_pretrained(self, save_directory: str, **kwargs):
        # Note: despite the ".safetensors" default filename, the checkpoint is
        # written with torch.save as {"state_dict": ...}; from_pretrained
        # below expects the same layout.
        filename = kwargs.pop("filename", "model.safetensors")
        modelpath = f"{save_directory}/{filename}"
        self.config.save_pretrained(save_directory)
        torch.save({"state_dict": self.state_dict()}, modelpath)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        filename = kwargs.pop("filename", "model.safetensors")
        config = MAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        # Resolve the checkpoint locally or from the Hub, then load the
        # torch-pickled state dict written by save_pretrained above.
        modelpath = cached_file(pretrained_model_name_or_path, filename=filename)
        state_dict = torch.load(modelpath, map_location="cpu")
        model = cls(config)
        model.load_state_dict(state_dict["state_dict"])
        return model
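

# A minimal end-to-end sketch (hypothetical shapes; assumes the hardcoded
# 6-modality, 256x256 setup above and default config values):
if __name__ == "__main__":
    config = MAEConfig()
    model = MAEModel(config)
    imgs = torch.randn(2, 6, 256, 256)
    latent, reconstruction, mask = model(imgs)
    loss, loss_dict = model.compute_MAE_loss(reconstruction, imgs, mask)
    embeddings = model.predict(imgs)  # [2, 384] with default mean pooling
    print(loss.item(), loss_dict, embeddings.shape)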