from typing import Tuple, Union

import torch
|
def transformer_random_masking(
    x: torch.Tensor, mask_ratio: float, constant_noise: Union[torch.Tensor, None] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Randomly mask tokens per sample (MAE-style argsort-of-noise masking).

    Parameters
    ----------
    x : torch.Tensor
        Token tensor of shape (N, L, D).
    mask_ratio : float
        Fraction of the L tokens to mask, in [0, 1].
    constant_noise : torch.Tensor or None, optional
        If provided, a tensor of shape (N, L) used in place of fresh random
        noise so the same mask is produced on every call.

    Returns
    -------
    x_masked : torch.Tensor
        The kept (unmasked) tokens, shape (N, int(L * (1 - mask_ratio)), D).
    mask : torch.Tensor
        Binary mask of shape (N, L), aligned with the original token order;
        1 marks masked (dropped) tokens, 0 marks kept tokens.
    ind_restore : torch.Tensor
        Permutation indices of shape (N, L) that un-shuffle tokens back to
        their original order; needed by the decoder.
    """
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))

    if constant_noise is not None:
        noise = constant_noise
    else:
        noise = torch.rand(N, L, device=x.device)

    # Sorting per-token noise yields a uniformly random permutation of the
    # L positions; taking the first len_keep entries is a uniform subset.
    shuffled_tokens = torch.argsort(noise, dim=1)
    # argsort of an argsort inverts the permutation.
    ind_restore = torch.argsort(shuffled_tokens, dim=1)

    tokens_to_keep = shuffled_tokens[:, :len_keep]
    x_masked = torch.gather(
        x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D)
    )

    # Build the mask in shuffled order (kept tokens first, masked after),
    # then un-shuffle it so it aligns with the original token order.
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ind_restore)

    return x_masked, mask, ind_restore
| |
|