| | |
| | from typing import List, Tuple |
| |
|
| | import torch |
| | from mmcv.ops import point_sample |
| | from mmengine.structures import InstanceData |
| | from torch import Tensor |
| |
|
| | from mmseg.registry import TASK_UTILS |
| | from mmseg.utils import ConfigType, SampleList |
| |
|
| |
|
def seg_data_to_instance_data(ignore_index: int,
                              batch_data_samples: SampleList):
    """Convert the paradigm of ground truth from semantic segmentation to
    instance segmentation.

    Each unique class id present in an image's semantic map (except
    ``ignore_index``) becomes one "instance" with a binary mask.

    Args:
        ignore_index (int): The label index to be ignored.
        batch_data_samples (List[SegDataSample]): The Data
            Samples. It usually includes information such as
            `gt_sem_seg`.

    Returns:
        List[InstanceData]: Batch of gt_instance. Each usually includes
            ``labels``, the unique ground truth label ids of one image
            with shape (num_gt, ), and ``masks``, the ground truth mask
            of each instance of the image with shape (num_gt, h, w).
    """
    batch_gt_instances = []

    for data_sample in batch_data_samples:
        gt_sem_seg = data_sample.gt_sem_seg.data
        # Order of class ids is irrelevant for matching, so skip sorting.
        # (`return_inverse`/`return_counts` default to False already.)
        classes = torch.unique(gt_sem_seg, sorted=False)

        # Drop the ignored label so it never forms an instance.
        gt_labels = classes[classes != ignore_index]

        # One binary mask per remaining class id.
        masks = [gt_sem_seg == class_id for class_id in gt_labels]

        if len(masks) == 0:
            # No valid class in this image: emit an empty (0, h, w) long
            # tensor on the same device/dtype family as the input.
            gt_masks = torch.zeros(
                (0, gt_sem_seg.shape[-2],
                 gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
        else:
            # gt_sem_seg carries a leading channel dim of 1 (presumably
            # (1, h, w)); squeeze it so gt_masks is (num_gt, h, w).
            gt_masks = torch.stack(masks).squeeze(1).long()

        instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
        batch_gt_instances.append(instance_data)
    return batch_gt_instances
| |
|
| |
|
class MatchMasks:
    """Match the predictions to category labels.

    Args:
        num_points (int): the number of sampled points to compute cost.
        num_queries (int): the number of prediction masks.
        num_classes (int): the number of classes.
        assigner (BaseAssigner): the assigner to compute matching.
    """

    def __init__(self,
                 num_points: int,
                 num_queries: int,
                 num_classes: int,
                 assigner: ConfigType = None):
        # Note the trailing space: the message is built from two adjacent
        # string literals.
        assert assigner is not None, "'assigner' in decode_head.train_cfg " \
            'cannot be None'
        assert num_points > 0, 'num_points should be a positive integer.'
        self.num_points = num_points
        self.num_queries = num_queries
        self.num_classes = num_classes
        self.assigner = TASK_UTILS.build(assigner)

    def get_targets(self, cls_scores: Tensor, mask_preds: Tensor,
                    batch_gt_instances: List[InstanceData]) -> Tuple:
        """Compute best mask matches for all images for a decoder layer.

        Args:
            cls_scores (Tensor): Mask score logits from a single
                decoder layer for all images, with shape (batch_size,
                num_queries, cls_out_channels).
            mask_preds (Tensor): Mask logits from a single decoder
                layer for all images, with shape (batch_size,
                num_queries, h, w).
            batch_gt_instances (List[InstanceData]): each contains
                ``labels`` and ``masks``.

        Returns:
            tuple: a tuple containing the following targets.

                - labels (Tensor): Labels of all images, stacked to\
                    shape (batch_size, num_queries).
                - mask_targets (Tensor): Mask targets of all images,\
                    concatenated along dim 0 to shape\
                    (num_total_gts, h, w).
                - mask_weights (Tensor): Mask weights of all images,\
                    stacked to shape (batch_size, num_queries).
                - avg_factor (int): Average factor that is used to
                  average the loss. `avg_factor` is usually equal
                  to the number of positive priors.
        """
        batch_size = cls_scores.shape[0]
        labels_list = []
        mask_targets_list = []
        mask_weights_list = []
        for i in range(batch_size):
            labels_i, mask_targets_i, mask_weights_i = \
                self._get_targets_single(cls_scores[i], mask_preds[i],
                                         batch_gt_instances[i])
            labels_list.append(labels_i)
            mask_targets_list.append(mask_targets_i)
            mask_weights_list.append(mask_weights_i)

        labels = torch.stack(labels_list, dim=0)
        # Per-image targets have varying first dims (num matched gts),
        # so concatenate instead of stacking.
        mask_targets = torch.cat(mask_targets_list, dim=0)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # Total number of gt instances across the batch (each gt is
        # matched to exactly one query below).
        avg_factor = sum(
            len(gt_instances.labels) for gt_instances in batch_gt_instances)

        return labels, mask_targets, mask_weights, avg_factor

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData) \
            -> Tuple[Tensor, Tensor, Tensor]:
        """Compute a set of best mask matches for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image. \
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image. \
                    shape (num_queries, h, w).
                - mask_weights (Tensor): Mask weights of each image. \
                    shape (num_queries, ).
        """
        gt_labels = gt_instances.labels
        gt_masks = gt_instances.masks
        # No gt in this image: every query gets the "no object" label
        # (num_classes) and zero mask weight; mask_targets is the empty
        # gt_labels tensor, which `torch.cat` in `get_targets` tolerates.
        if len(gt_labels) == 0:
            labels = gt_labels.new_full((self.num_queries, ),
                                        self.num_classes,
                                        dtype=torch.long)
            mask_targets = gt_labels
            mask_weights = gt_labels.new_zeros((self.num_queries, ))
            return labels, mask_targets, mask_weights

        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]

        # Sample the SAME random point set for predictions and gts so
        # the assignment cost compares like with like.
        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(
            mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
                                                        1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(
            gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
                                                               1)).squeeze(1)

        sampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_points_masks)
        sampled_pred_instances = InstanceData(
            scores=cls_score, masks=mask_points_pred)
        # Hungarian-style assignment: which query is matched to which gt.
        matched_query_inds, matched_label_inds = self.assigner.assign(
            pred_instances=sampled_pred_instances,
            gt_instances=sampled_gt_instances)
        # Unmatched queries keep the background label `num_classes`.
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[matched_query_inds] = gt_labels[matched_label_inds]

        # Only matched queries contribute to the mask loss.
        mask_weights = gt_labels.new_zeros((self.num_queries, ))
        mask_weights[matched_query_inds] = 1
        mask_targets = gt_masks[matched_label_inds]

        return labels, mask_targets, mask_weights
| |
|