File size: 2,326 Bytes
eca55dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import numpy as np
from typing import Tuple, Optional


class MaskingGenerator:
    """
    Generates random boolean masks over the patches of an image.

    Each mask is a flat boolean vector with one entry per patch, where
    ``True`` marks a masked patch. The number of masked patches is drawn
    independently for every sample from the configured ratio range, and the
    masked positions are chosen uniformly at random.

    Note:
        The mask ratio is sampled with NumPy's global RNG
        (``np.random.uniform``) while the patch positions are sampled with
        torch's global RNG (``torch.randperm``). Seed both generators to get
        reproducible masks.

    Args:
        input_size (tuple[int, int]): Input image size (H, W).
        patch_size (tuple[int, int]): Patch size (H, W).
        mask_ratio (tuple[float, float]): Range of mask ratio (min, max).
    """

    def __init__(
        self,
        input_size: Tuple[int, int] = (128, 256),
        patch_size: Tuple[int, int] = (16, 16),
        mask_ratio: Tuple[float, float] = (0.4, 0.6),
    ):
        self.height, self.width = input_size
        self.patch_h, self.patch_w = patch_size
        # Floor division: pixels that do not fill a whole patch are ignored.
        self.num_patches_h = self.height // self.patch_h
        self.num_patches_w = self.width // self.patch_w
        self.num_patches = self.num_patches_h * self.num_patches_w
        self.mask_ratio = mask_ratio

    def __call__(
        self,
        batch_size: int,
        device: torch.device = torch.device("cpu"),
        grid_size: Optional[Tuple[int, int]] = None,
    ) -> torch.Tensor:
        """
        Generate masks for a batch.

        Args:
            batch_size (int): Batch size. Must be non-negative; ``0`` yields
                an empty ``[0, N]`` tensor.
            device (torch.device): Device to place the masks on.
            grid_size (Optional[Tuple[int, int]]): Grid size (H, W) if different from init.

        Returns:
            torch.Tensor: Masks [B, N] (boolean, True=masked).

        Raises:
            ValueError: If ``batch_size`` is negative.
        """
        if batch_size < 0:
            raise ValueError(f"batch_size must be non-negative, got {batch_size}")

        if batch_size == 0:
            # torch.stack rejects an empty list, so build the empty batch
            # explicitly with the correct trailing dimension.
            num_patches = self._resolve_num_patches(grid_size)
            return torch.zeros((0, num_patches), dtype=torch.bool, device=device)

        # Masks are sampled one at a time so that every sample gets its own
        # mask ratio and its own permutation.
        masks = [self._generate_mask(grid_size) for _ in range(batch_size)]
        return torch.stack(masks).to(device)

    def _resolve_num_patches(
        self, grid_size: Optional[Tuple[int, int]] = None
    ) -> int:
        """
        Return the number of patches for ``grid_size``, or the count computed
        at construction time when ``grid_size`` is ``None``.
        """
        if grid_size is None:
            return self.num_patches
        grid_h, grid_w = grid_size
        return grid_h * grid_w

    def _generate_mask(
        self, grid_size: Optional[Tuple[int, int]] = None
    ) -> torch.Tensor:
        """
        Generate a single mask.

        Args:
            grid_size (Optional[Tuple[int, int]]): Grid size (H, W) overriding
                the patch grid derived in ``__init__``.

        Returns:
            torch.Tensor: Flat boolean mask of length N (True=masked), created
            on the CPU.
        """
        num_patches = self._resolve_num_patches(grid_size)
        mask = torch.zeros(num_patches, dtype=torch.bool)

        # int() truncates toward zero, so the realized ratio never exceeds
        # the sampled one.
        target_masked = int(num_patches * np.random.uniform(*self.mask_ratio))

        # Random Permutation Masking: the first `target_masked` entries of a
        # random permutation are the masked patch indices.
        if target_masked > 0:
            mask[torch.randperm(num_patches)[:target_masked]] = True

        return mask  # Already flattened