| """ |
| Train and eval functions used in main.py |
| """ |
| import math |
| import os |
| import sys |
| from typing import Iterable |
|
|
| import torch |
| import util.misc as utils |
| from datasets.coco_eval import CocoEvaluator, convert_to_xywh |
| from datasets.data_prefetcher import data_prefetcher |
| from datasets.panoptic_eval import PanopticEvaluator |
| from util.ema import requires_grad, update_ema |
| from util.misc import NestedTensor |
|
|
|
|
def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    max_norm: float = 0,
    ema: torch.nn.Module = None,
    ema_decay: float = 0.999,
):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    metric_logger.add_meter(
        "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
    )
    metric_logger.add_meter(
        "grad_norm", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
    )
    header = "Epoch: [{}]".format(epoch)
    print_freq = 10

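    # The prefetcher copies the next batch to the GPU ahead of time so the
    # host-to-device transfer overlaps with the forward/backward pass.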
    prefetcher = data_prefetcher(data_loader, device, prefetch=True)
    samples, targets = prefetcher.next()

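    # The prefetcher replaces direct iteration over data_loader, so the loop
    # is driven by an index range of the same length.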
    for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header):
        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(
            loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict
        )

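        # Reduce losses across processes for logging only; the backward pass
        # below uses the local, unreduced `losses`.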
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items()
            if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

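        # Backprop and step; clip gradients when max_norm > 0, otherwise just
        # measure the total gradient norm for logging.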
        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            grad_total_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm
            )
        else:
            grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
        optimizer.step()

        if ema is not None:
            # The EMA copy tracks the raw model; unwrap DistributedDataParallel
            # when present so single-GPU runs also work.
            raw_model = model.module if hasattr(model, "module") else model
            update_ema(ema, raw_model, ema_decay)

        metric_logger.update(
            loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled
        )
        metric_logger.update(class_error=loss_dict_reduced["class_error"])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(grad_norm=grad_total_norm)

        samples, targets = prefetcher.next()

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(
    model_no_ema,
    criterion,
    postprocessors,
    data_loader,
    base_ds,
    device,
    output_dir,
    test_hflip_aug,
    tta,
    soft_nms,
    ema=None,
    save_result=False,
    save_result_dir="",
    soft_nms_method="quad",
    nms_thresh=0.7,
    quad_scale=0.5,
    lsj_img_size=1824,
):
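    # Evaluate the EMA weights when provided; otherwise use the raw model.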
    model = model_no_ema if ema is None else ema
    model.eval()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
    )
    header = "Test:"

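    # Run COCO evaluation for whichever of the bbox/segm postprocessors are configured.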
    iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)

    panoptic_evaluator = None
    if "panoptic" in postprocessors.keys():
        panoptic_evaluator = PanopticEvaluator(
            data_loader.dataset.ann_file,
            data_loader.dataset.ann_folder,
            output_dir=os.path.join(output_dir, "panoptic_eval"),
        )

    prediction_list = []
    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

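        # Horizontal-flip TTA: the loader stacks each image and its flipped
        # copy along the channel dimension (3 + 3 = 6 channels); the model is
        # run on each half separately.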
        if test_hflip_aug:
            assert (
                samples.tensors.shape[0] == 1
            ), "test_hflip_aug only supports batch size 1"
            assert (
                samples.tensors.shape[1] == 6
            ), "test_hflip_aug expects the image and its horizontal flip concatenated along the channel dim (6 channels)"
            first_samples = NestedTensor(samples.tensors[:, :3], samples.mask)
            outputs = model(first_samples)
            flipped_samples = NestedTensor(samples.tensors[:, 3:], samples.mask)
            flipped_outputs = model(flipped_samples)
        else:
            outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items()
            if k in weight_dict
        }
        loss_dict_reduced_unscaled = {
            f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
        }
        metric_logger.update(
            loss=sum(loss_dict_reduced_scaled.values()),
            **loss_dict_reduced_scaled,
            **loss_dict_reduced_unscaled,
        )
        metric_logger.update(class_error=loss_dict_reduced["class_error"])

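        # Rescale predictions back to absolute coordinates in the original image size.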
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        if test_hflip_aug:
            new_outputs = {}
            pred_logits = outputs["pred_logits"]
            pred_boxes = outputs["pred_boxes"]

            flipped_pred_logits = flipped_outputs["pred_logits"]
            flipped_pred_boxes = flipped_outputs["pred_boxes"]

            # Undo the horizontal flip on the predicted boxes. Boxes are in
            # normalized (cx, cy, w, h) format, so only the x-center changes:
            # cx -> 1 - cx.
            reflipped_pred_boxes = flipped_pred_boxes * torch.as_tensor(
                [-1, 1, 1, 1], device=flipped_pred_boxes.device
            ) + torch.as_tensor([1, 0, 0, 0], device=flipped_pred_boxes.device)

            # Merge the original and un-flipped predictions into one set of
            # queries; (soft-)NMS in the bbox postprocessor prunes duplicates.
            new_pred_logits = torch.cat([pred_logits, flipped_pred_logits], dim=1)
            new_pred_boxes = torch.cat([pred_boxes, reflipped_pred_boxes], dim=1)

            new_outputs["pred_logits"] = new_pred_logits
            new_outputs["pred_boxes"] = new_pred_boxes
            results = postprocessors["bbox"](
                new_outputs,
                orig_target_sizes,
                soft_nms=soft_nms,
                method=soft_nms_method,
                nms_thresh=nms_thresh,
                quad_scale=quad_scale,
            )
        else:
            results = postprocessors["bbox"](
                outputs,
                orig_target_sizes,
                soft_nms=soft_nms,
                method=soft_nms_method,
                nms_thresh=nms_thresh,
                quad_scale=quad_scale,
            )
| if "segm" in postprocessors.keys(): |
| target_sizes = torch.stack([t["size"] for t in targets], dim=0) |
| results = postprocessors["segm"]( |
| results, outputs, orig_target_sizes, target_sizes |
| ) |
| res = { |
| target["image_id"].item(): output |
| for target, output in zip(targets, results) |
| } |
| if coco_evaluator is not None: |
| coco_evaluator.update(res) |
|
|
        if panoptic_evaluator is not None:
            # Compute the padded target sizes here rather than relying on the
            # "segm" branch above having defined them.
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
            res_pano = postprocessors["panoptic"](
                outputs, target_sizes, orig_target_sizes
            )
            for i, target in enumerate(targets):
                image_id = target["image_id"].item()
                file_name = f"{image_id:012d}.png"
                res_pano[i]["image_id"] = image_id
                res_pano[i]["file_name"] = file_name

            panoptic_evaluator.update(res_pano)

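        # Keep a CPU copy of each image's predictions for optional saving below.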
        for target, output in zip(targets, results):
            res_cpu = {
                target["image_id"].item(): {
                    "boxes": output["boxes"].cpu(),
                    "labels": output["labels"].cpu(),
                    "scores": output["scores"].cpu(),
                }
            }
            prediction_list.append(res_cpu)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    if save_result:
        from torch import distributed as dist

        os.makedirs(save_result_dir, exist_ok=True)
        # Each process writes its own shard of predictions; fall back to rank 0
        # when not running distributed.
        rank = dist.get_rank() if dist.is_initialized() else 0
        torch.save(
            prediction_list,
            os.path.join(save_result_dir, f"val2017_prediction_{rank}.pth"),
        )

    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
    if panoptic_evaluator is not None:
        panoptic_evaluator.synchronize_between_processes()

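    # Accumulate per-image results and print the final COCO summary tables.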
    if coco_evaluator is not None:
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
    panoptic_res = None
    if panoptic_evaluator is not None:
        panoptic_res = panoptic_evaluator.summarize()
    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if coco_evaluator is not None:
        if "bbox" in postprocessors.keys():
            stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
        if "segm" in postprocessors.keys():
            stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()
    if panoptic_res is not None:
        stats["PQ_all"] = panoptic_res["All"]
        stats["PQ_th"] = panoptic_res["Things"]
        stats["PQ_st"] = panoptic_res["Stuff"]
    return stats, coco_evaluator