| | |
| | |
| |
|
| | from collections import Counter |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| |
|
| | def find_instance_center(center_heatmap, threshold=0.1, nms_kernel=3, top_k=None): |
| | """ |
| | Find the center points from the center heatmap. |
| | Args: |
| | center_heatmap: A Tensor of shape [1, H, W] of raw center heatmap output. |
| | threshold: A float, threshold applied to center heatmap score. |
| | nms_kernel: An integer, NMS max pooling kernel size. |
| | top_k: An integer, top k centers to keep. |
| | Returns: |
| | A Tensor of shape [K, 2] where K is the number of center points. The |
| | order of second dim is (y, x). |
| | """ |
| | |
| | center_heatmap = F.threshold(center_heatmap, threshold, -1) |
| |
|
| | |
| | nms_padding = (nms_kernel - 1) // 2 |
| | center_heatmap_max_pooled = F.max_pool2d( |
| | center_heatmap, kernel_size=nms_kernel, stride=1, padding=nms_padding |
| | ) |
| | center_heatmap[center_heatmap != center_heatmap_max_pooled] = -1 |
| |
|
| | |
| | center_heatmap = center_heatmap.squeeze() |
| | assert len(center_heatmap.size()) == 2, "Something is wrong with center heatmap dimension." |
| |
|
| | |
| | if top_k is None: |
| | return torch.nonzero(center_heatmap > 0) |
| | else: |
| | |
| | top_k_scores, _ = torch.topk(torch.flatten(center_heatmap), top_k) |
| | return torch.nonzero(center_heatmap > top_k_scores[-1].clamp_(min=0)) |
| |
|
| |
|
| | def group_pixels(center_points, offsets): |
| | """ |
| | Gives each pixel in the image an instance id. |
| | Args: |
| | center_points: A Tensor of shape [K, 2] where K is the number of center points. |
| | The order of second dim is (y, x). |
| | offsets: A Tensor of shape [2, H, W] of raw offset output. The order of |
| | second dim is (offset_y, offset_x). |
| | Returns: |
| | A Tensor of shape [1, H, W] with values in range [1, K], which represents |
| | the center this pixel belongs to. |
| | """ |
| | height, width = offsets.size()[1:] |
| |
|
| | |
| | |
| | y_coord, x_coord = torch.meshgrid( |
| | torch.arange(height, dtype=offsets.dtype, device=offsets.device), |
| | torch.arange(width, dtype=offsets.dtype, device=offsets.device), |
| | ) |
| | coord = torch.cat((y_coord.unsqueeze(0), x_coord.unsqueeze(0)), dim=0) |
| |
|
| | center_loc = coord + offsets |
| | center_loc = center_loc.flatten(1).T.unsqueeze_(0) |
| | center_points = center_points.unsqueeze(1) |
| |
|
| | |
| | distance = torch.norm(center_points - center_loc, dim=-1) |
| |
|
| | |
| | |
| | instance_id = torch.argmin(distance, dim=0).reshape((1, height, width)) + 1 |
| | return instance_id |
| |
|
| |
|
| | def get_instance_segmentation( |
| | sem_seg, center_heatmap, offsets, thing_seg, thing_ids, threshold=0.1, nms_kernel=3, top_k=None |
| | ): |
| | """ |
| | Post-processing for instance segmentation, gets class agnostic instance id. |
| | Args: |
| | sem_seg: A Tensor of shape [1, H, W], predicted semantic label. |
| | center_heatmap: A Tensor of shape [1, H, W] of raw center heatmap output. |
| | offsets: A Tensor of shape [2, H, W] of raw offset output. The order of |
| | second dim is (offset_y, offset_x). |
| | thing_seg: A Tensor of shape [1, H, W], predicted foreground mask, |
| | if not provided, inference from semantic prediction. |
| | thing_ids: A set of ids from contiguous category ids belonging |
| | to thing categories. |
| | threshold: A float, threshold applied to center heatmap score. |
| | nms_kernel: An integer, NMS max pooling kernel size. |
| | top_k: An integer, top k centers to keep. |
| | Returns: |
| | A Tensor of shape [1, H, W] with value 0 represent stuff (not instance) |
| | and other positive values represent different instances. |
| | A Tensor of shape [1, K, 2] where K is the number of center points. |
| | The order of second dim is (y, x). |
| | """ |
| | center_points = find_instance_center( |
| | center_heatmap, threshold=threshold, nms_kernel=nms_kernel, top_k=top_k |
| | ) |
| | if center_points.size(0) == 0: |
| | return torch.zeros_like(sem_seg), center_points.unsqueeze(0) |
| | ins_seg = group_pixels(center_points, offsets) |
| | return thing_seg * ins_seg, center_points.unsqueeze(0) |
| |
|
| |
|
| | def merge_semantic_and_instance( |
| | sem_seg, ins_seg, semantic_thing_seg, label_divisor, thing_ids, stuff_area, void_label |
| | ): |
| | """ |
| | Post-processing for panoptic segmentation, by merging semantic segmentation |
| | label and class agnostic instance segmentation label. |
| | Args: |
| | sem_seg: A Tensor of shape [1, H, W], predicted category id for each pixel. |
| | ins_seg: A Tensor of shape [1, H, W], predicted instance id for each pixel. |
| | semantic_thing_seg: A Tensor of shape [1, H, W], predicted foreground mask. |
| | label_divisor: An integer, used to convert panoptic id = |
| | semantic id * label_divisor + instance_id. |
| | thing_ids: Set, a set of ids from contiguous category ids belonging |
| | to thing categories. |
| | stuff_area: An integer, remove stuff whose area is less tan stuff_area. |
| | void_label: An integer, indicates the region has no confident prediction. |
| | Returns: |
| | A Tensor of shape [1, H, W]. |
| | """ |
| | |
| | pan_seg = torch.zeros_like(sem_seg) + void_label |
| | is_thing = (ins_seg > 0) & (semantic_thing_seg > 0) |
| |
|
| | |
| | class_id_tracker = Counter() |
| |
|
| | |
| | instance_ids = torch.unique(ins_seg) |
| | for ins_id in instance_ids: |
| | if ins_id == 0: |
| | continue |
| | |
| | thing_mask = (ins_seg == ins_id) & is_thing |
| | if torch.nonzero(thing_mask).size(0) == 0: |
| | continue |
| | class_id, _ = torch.mode(sem_seg[thing_mask].view(-1)) |
| | class_id_tracker[class_id.item()] += 1 |
| | new_ins_id = class_id_tracker[class_id.item()] |
| | pan_seg[thing_mask] = class_id * label_divisor + new_ins_id |
| |
|
| | |
| | class_ids = torch.unique(sem_seg) |
| | for class_id in class_ids: |
| | if class_id.item() in thing_ids: |
| | |
| | continue |
| | |
| | stuff_mask = (sem_seg == class_id) & (ins_seg == 0) |
| | if stuff_mask.sum().item() >= stuff_area: |
| | pan_seg[stuff_mask] = class_id * label_divisor |
| |
|
| | return pan_seg |
| |
|
| |
|
| | def get_panoptic_segmentation( |
| | sem_seg, |
| | center_heatmap, |
| | offsets, |
| | thing_ids, |
| | label_divisor, |
| | stuff_area, |
| | void_label, |
| | threshold=0.1, |
| | nms_kernel=7, |
| | top_k=200, |
| | foreground_mask=None, |
| | ): |
| | """ |
| | Post-processing for panoptic segmentation. |
| | Args: |
| | sem_seg: A Tensor of shape [1, H, W] of predicted semantic label. |
| | center_heatmap: A Tensor of shape [1, H, W] of raw center heatmap output. |
| | offsets: A Tensor of shape [2, H, W] of raw offset output. The order of |
| | second dim is (offset_y, offset_x). |
| | thing_ids: A set of ids from contiguous category ids belonging |
| | to thing categories. |
| | label_divisor: An integer, used to convert panoptic id = |
| | semantic id * label_divisor + instance_id. |
| | stuff_area: An integer, remove stuff whose area is less tan stuff_area. |
| | void_label: An integer, indicates the region has no confident prediction. |
| | threshold: A float, threshold applied to center heatmap score. |
| | nms_kernel: An integer, NMS max pooling kernel size. |
| | top_k: An integer, top k centers to keep. |
| | foreground_mask: Optional, A Tensor of shape [1, H, W] of predicted |
| | binary foreground mask. If not provided, it will be generated from |
| | sem_seg. |
| | Returns: |
| | A Tensor of shape [1, H, W], int64. |
| | """ |
| | if sem_seg.dim() != 3 and sem_seg.size(0) != 1: |
| | raise ValueError("Semantic prediction with un-supported shape: {}.".format(sem_seg.size())) |
| | if center_heatmap.dim() != 3: |
| | raise ValueError( |
| | "Center prediction with un-supported dimension: {}.".format(center_heatmap.dim()) |
| | ) |
| | if offsets.dim() != 3: |
| | raise ValueError("Offset prediction with un-supported dimension: {}.".format(offsets.dim())) |
| | if foreground_mask is not None: |
| | if foreground_mask.dim() != 3 and foreground_mask.size(0) != 1: |
| | raise ValueError( |
| | "Foreground prediction with un-supported shape: {}.".format(sem_seg.size()) |
| | ) |
| | thing_seg = foreground_mask |
| | else: |
| | |
| | thing_seg = torch.zeros_like(sem_seg) |
| | for thing_class in list(thing_ids): |
| | thing_seg[sem_seg == thing_class] = 1 |
| |
|
| | instance, center = get_instance_segmentation( |
| | sem_seg, |
| | center_heatmap, |
| | offsets, |
| | thing_seg, |
| | thing_ids, |
| | threshold=threshold, |
| | nms_kernel=nms_kernel, |
| | top_k=top_k, |
| | ) |
| | panoptic = merge_semantic_and_instance( |
| | sem_seg, instance, thing_seg, label_divisor, thing_ids, stuff_area, void_label |
| | ) |
| |
|
| | return panoptic, center |
| |
|