from typing import Iterable, Tuple

import torch
import torch.nn.functional as F


def pad_divide_by(in_img: torch.Tensor, d: int) -> Tuple[torch.Tensor, Iterable[int]]:
    """Pad the last two (spatial) dimensions of in_img so that both are divisible by d.

    Returns the padded tensor and the (left, right, top, bottom) padding amounts,
    which can be passed to unpad() to recover the original spatial size.
    """
    h, w = in_img.shape[-2:]

    if h % d > 0:
        new_h = h + d - h % d
    else:
        new_h = h
    if w % d > 0:
        new_w = w + d - w % d
    else:
        new_w = w
    lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
    lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
    pad_array = (int(lw), int(uw), int(lh), int(uh))
    out = F.pad(in_img, pad_array)
    return out, pad_array


def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor:
    # Undo pad_divide_by(): strip the (left, right, top, bottom) padding from the
    # last two spatial dimensions of a 3D, 4D, or 5D tensor.
    if len(img.shape) == 4:
        if pad[2] + pad[3] > 0:
            img = img[:, :, pad[2]:-pad[3], :]
        if pad[0] + pad[1] > 0:
            img = img[:, :, :, pad[0]:-pad[1]]
    elif len(img.shape) == 3:
        if pad[2] + pad[3] > 0:
            img = img[:, pad[2]:-pad[3], :]
        if pad[0] + pad[1] > 0:
            img = img[:, :, pad[0]:-pad[1]]
    elif len(img.shape) == 5:
        if pad[2] + pad[3] > 0:
            img = img[:, :, :, pad[2]:-pad[3], :]
        if pad[0] + pad[1] > 0:
            img = img[:, :, :, :, pad[0]:-pad[1]]
    else:
        raise NotImplementedError
    return img
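
# A minimal usage sketch (illustrative, not part of the original module): pad_divide_by
# and unpad are meant to be used as a round trip, e.g. padding a frame so its spatial
# size is divisible by a network stride and cropping the output back afterwards.
# The stride of 16 and the tensor shape below are assumptions for the example.
#
#     frame = torch.rand(1, 3, 480, 854)        # B*C*H*W
#     padded, pads = pad_divide_by(frame, 16)   # padded: 1*3*480*864
#     restored = unpad(padded, pads)            # back to 1*3*480*854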


def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
    # Merge per-object probabilities into logits over (background + objects) along
    # `dim`: the background probability is prod(1 - p) over the objects. Autocast is
    # disabled so the clamp and log run in full precision.
    with torch.amp.autocast("cuda", enabled=False):
        prob = prob.float()
        new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob],
                             dim).clamp(1e-7, 1 - 1e-7)
        logits = torch.log(new_prob / (1 - new_prob))

    return logits
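
# Illustrative sketch (the shapes and the follow-up softmax are assumptions, not from
# this module): aggregate() is typically applied to a stack of per-object soft masks,
# producing num_objects + 1 logit channels that can then be normalised per pixel.
#
#     obj_prob = torch.rand(3, 480, 864)    # 3 objects, each an H*W soft mask
#     logits = aggregate(obj_prob, dim=0)   # 4*480*864: background + 3 objects
#     prob = torch.softmax(logits, dim=0)   # per-pixel distribution over classes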


def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor:
    # cls_gt: B*1*H*W integer label map with values in [0, num_objects] (0 = background).
    # Returns a B*(num_objects+1)*H*W one-hot tensor on the same device.
    B, _, H, W = cls_gt.shape
    one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1)
    return one_hot
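
# Illustrative sketch (the label range and shape are assumptions): cls_to_one_hot()
# expects an integer (long) label map and scatters it over the channel dimension.
#
#     cls_gt = torch.randint(0, 4, (2, 1, 64, 64))     # 2 images, labels 0..3
#     one_hot = cls_to_one_hot(cls_gt, num_objects=3)  # 2*4*64*64 one-hot tensor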