import torch
import numpy as np
from typing import Tuple, Optional
class MaskingGenerator:
    """
    Generates random boolean masks over the patches of an input.

    Each mask is a flat vector with one entry per patch; ``True`` marks a
    masked patch. The number of masked patches is drawn independently for
    every sample from the configured ratio range.

    Note:
        The mask ratio is sampled with ``np.random.uniform`` while the patch
        permutation uses ``torch.randperm``, so both the NumPy and the torch
        global RNGs must be seeded for reproducible masks.

    Args:
        input_size (tuple[int, int]): Input image size (H, W).
        patch_size (tuple[int, int]): Patch size (H, W).
        mask_ratio (tuple[float, float]): Range of mask ratio (min, max).
    """

    def __init__(
        self,
        input_size: Tuple[int, int] = (128, 256),
        patch_size: Tuple[int, int] = (16, 16),
        mask_ratio: Tuple[float, float] = (0.4, 0.6),
    ):
        self.height, self.width = input_size
        self.patch_h, self.patch_w = patch_size
        # Floor division: any remainder that does not fill a whole patch is
        # not counted.
        self.num_patches_h = self.height // self.patch_h
        self.num_patches_w = self.width // self.patch_w
        self.num_patches = self.num_patches_h * self.num_patches_w
        self.mask_ratio = mask_ratio

    def __call__(
        self,
        batch_size: int,
        device: torch.device = torch.device("cpu"),
        grid_size: Optional[Tuple[int, int]] = None,
    ) -> torch.Tensor:
        """
        Generate masks for a batch.

        Args:
            batch_size (int): Batch size. Zero yields an empty ``[0, N]`` tensor.
            device (torch.device): Device to place the masks on.
            grid_size (Optional[Tuple[int, int]]): Grid size (H, W) in patches,
                if different from the one derived at init.

        Returns:
            torch.Tensor: Masks [B, N] (boolean, True=masked).

        Raises:
            ValueError: If ``batch_size`` is negative.
        """
        if batch_size < 0:
            raise ValueError(f"batch_size must be non-negative, got {batch_size}")
        if batch_size == 0:
            # torch.stack rejects an empty list, so build the empty batch
            # directly with the correct trailing dimension.
            return torch.zeros(
                (0, self._resolve_num_patches(grid_size)),
                dtype=torch.bool,
                device=device,
            )
        masks = [self._generate_mask(grid_size) for _ in range(batch_size)]
        return torch.stack(masks).to(device)

    def _resolve_num_patches(
        self, grid_size: Optional[Tuple[int, int]] = None
    ) -> int:
        """
        Return the patch count for ``grid_size``, or the init-time count if None.
        """
        if grid_size is None:
            return self.num_patches
        grid_h, grid_w = grid_size
        return grid_h * grid_w

    def _generate_mask(
        self, grid_size: Optional[Tuple[int, int]] = None
    ) -> torch.Tensor:
        """
        Generate a single flattened mask.

        Args:
            grid_size (Optional[Tuple[int, int]]): Grid size (H, W) in patches,
                overriding the init-time grid when given.

        Returns:
            torch.Tensor: Boolean mask of shape [N] on the CPU (True=masked).
        """
        num_patches = self._resolve_num_patches(grid_size)
        mask = torch.zeros(num_patches, dtype=torch.bool)
        # Truncate toward zero so the masked count never exceeds the sampled ratio.
        target_masked = int(num_patches * np.random.uniform(*self.mask_ratio))
        # Random permutation masking: the first `target_masked` entries of a
        # random permutation are the masked patch indices.
        if target_masked > 0:
            perm = torch.randperm(num_patches)
            mask[perm[:target_masked]] = True
        return mask  # Already flattened
|