| | import os |
| | import sys |
| |
|
| | import torch |
| | import torchvision |
| | from fvcore.nn import FlopCountAnalysis |
| | from torch import nn |
| |
|
| | sys.path.append("vision/references/segmentation") |
| | from transforms import Compose |
| | from coco_utils import ConvertCocoPolysToMask |
| | from coco_utils import FilterAndRemapCocoCategories |
| | from coco_utils import _coco_remove_images_without_annotations |
| | from utils import ConfusionMatrix |
| |
|
| |
|
class NanSafeConfusionMatrix(ConfusionMatrix):
    """Confusion matrix whose computed metrics replace NaNs with zeros.

    NaNs arise when a class never appears in targets/predictions (division
    by an empty row or column total); they are mapped to 0 so downstream
    averaging does not propagate NaN.
    """

    def __init__(self, num_classes):
        super().__init__(num_classes=num_classes)

    def compute(self):
        """Compute metrics based on confusion matrix.

        Returns:
            Tuple of (global accuracy, per-class accuracy, per-class IoU)
            tensors, with NaNs from empty rows/columns replaced by zeros.
        """
        matrix = self.mat.float()
        diagonal = torch.diag(matrix)
        row_totals = matrix.sum(1)  # per-class target counts
        col_totals = matrix.sum(0)  # per-class prediction counts

        global_accuracy = torch.nan_to_num(diagonal.sum() / matrix.sum())
        per_class_accuracy = torch.nan_to_num(diagonal / row_totals)
        per_class_iou = torch.nan_to_num(diagonal / (row_totals + col_totals - diagonal))
        return global_accuracy, per_class_accuracy, per_class_iou
| |
|
| |
|
def flops_calculation_function(model: nn.Module, input_sample: torch.Tensor) -> float:
    """Calculate number of flops in millions.

    NOTE: ``model.eval()`` switches the model to inference mode as a
    side effect and is not restored afterwards.

    Args:
        model: network to analyze.
        input_sample: batched example input; the count is averaged over
            the batch (first) dimension.

    Returns:
        Per-sample flop count, in millions (MFLOPs).
    """
    analyzer = FlopCountAnalysis(model=model.eval(), inputs=input_sample)
    # Silence fvcore's warnings about ops it cannot count / uncalled modules.
    analyzer.unsupported_ops_warnings(False)
    analyzer.uncalled_modules_warnings(False)

    batch_size = input_sample.shape[0]
    per_sample_flops = analyzer.total() / batch_size
    return per_sample_flops / 1e6
| |
|
| |
|
def get_coco(root, image_set, transforms, use_v2=False, use_orig=False):
    """Get COCO dataset with VOC or COCO classes.

    Args:
        root: dataset root containing the image folders and ``annotations`` dir.
        image_set: split name, ``"train"`` or ``"val"``.
        transforms: transform pipeline applied to each (image, target) pair.
        use_v2: if True, build with torchvision transforms-v2 wrapping.
        use_orig: if True, keep all 81 COCO categories instead of remapping
            to the 21 VOC-compatible ones.

    Returns:
        A ``CocoDetection`` dataset; the train split is filtered to images
        that actually carry annotations.
    """
    split_layout = {
        "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
        "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
    }
    # Either every COCO category or the 21 categories that map onto PASCAL VOC.
    if use_orig:
        categories = list(range(81))
    else:
        categories = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

    image_subdir, annotation_rel = split_layout[image_set]
    image_dir = os.path.join(root, image_subdir)
    annotation_path = os.path.join(root, annotation_rel)

    if use_v2:
        import v2_extras
        from torchvision.datasets import wrap_dataset_for_transforms_v2

        pipeline = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
        dataset = torchvision.datasets.CocoDetection(image_dir, annotation_path, transforms=pipeline)
        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
    else:
        pipeline = Compose(
            [FilterAndRemapCocoCategories(categories, remap=True), ConvertCocoPolysToMask(), transforms]
        )
        dataset = torchvision.datasets.CocoDetection(image_dir, annotation_path, transforms=pipeline)

    if image_set == "train":
        dataset = _coco_remove_images_without_annotations(dataset, categories)

    return dataset
| |
|