| import torch |
| import numpy as np |
| from typing import Tuple, Optional |
|
|
|
|
| class MaskingGenerator: |
| """ |
| Generates masks for the input patches. |
| |
| Args: |
| input_size (tuple[int, int]): Input image size (H, W). |
| patch_size (tuple[int, int]): Patch size (H, W). |
| mask_ratio (tuple[float, float]): Range of mask ratio (min, max). |
| """ |
|
|
| def __init__( |
| self, |
| input_size: Tuple[int, int] = (128, 256), |
| patch_size: Tuple[int, int] = (16, 16), |
| mask_ratio: Tuple[float, float] = (0.4, 0.6), |
| ): |
| self.height, self.width = input_size |
| self.patch_h, self.patch_w = patch_size |
| self.num_patches_h = self.height // self.patch_h |
| self.num_patches_w = self.width // self.patch_w |
| self.num_patches = self.num_patches_h * self.num_patches_w |
| self.mask_ratio = mask_ratio |
|
|
| def __call__( |
| self, |
| batch_size: int, |
| device: torch.device = torch.device("cpu"), |
| grid_size: Optional[Tuple[int, int]] = None, |
| ) -> torch.Tensor: |
| """ |
| Generate masks for a batch. |
| |
| Args: |
| batch_size (int): Batch size. |
| device (torch.device): Device to place the masks on. |
| grid_size (Optional[Tuple[int, int]]): Grid size (H, W) if different from init. |
| |
| Returns: |
| torch.Tensor: Masks [B, N] (boolean, True=masked). |
| """ |
| masks = [] |
| for _ in range(batch_size): |
| mask = self._generate_mask(grid_size) |
| masks.append(mask) |
|
|
| return torch.stack(masks).to(device) |
|
|
| def _generate_mask( |
| self, grid_size: Optional[Tuple[int, int]] = None |
| ) -> torch.Tensor: |
| """ |
| Generate a single mask. |
| """ |
| if grid_size is not None: |
| num_patches_h, num_patches_w = grid_size |
| num_patches = num_patches_h * num_patches_w |
| else: |
| num_patches = self.num_patches |
|
|
| mask = torch.zeros(num_patches, dtype=torch.bool) |
|
|
| target_masked = int(num_patches * np.random.uniform(*self.mask_ratio)) |
|
|
| |
| if target_masked > 0: |
| perm = torch.randperm(num_patches) |
| mask_indices = perm[:target_masked] |
| mask[mask_indices] = True |
|
|
| return mask |
|
|