| | |
| | |
| | """ |
| | Misc functions, including distributed helpers. |
| | |
| | Mostly copy-paste from torchvision references. |
| | """ |
| | from typing import List, Optional |
| | from collections import OrderedDict |
| | from scipy.io import loadmat |
| | import numpy as np |
| | import csv |
| | from PIL import Image |
| | import matplotlib.pyplot as plt |
| | import torch |
| | import torch.distributed as dist |
| | import torchvision |
| | from torch import Tensor |
| |
|
| |
|
| | def _max_by_axis(the_list): |
| | |
| | maxes = the_list[0] |
| | for sublist in the_list[1:]: |
| | for index, item in enumerate(sublist): |
| | maxes[index] = max(maxes[index], item) |
| | return maxes |
| |
|
def get_world_size() -> int:
    """Return the distributed world size, or 1 when torch.distributed is unavailable or uninitialized."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
| |
|
def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary across all processes with all_reduce.

    Args:
        input_dict (dict): all the values will be reduced; each value must be
            a tensor that can be stacked with the others.
        average (bool): whether to do average or sum

    Returns:
        A dict with the same keys as input_dict, after reduction. With a world
        size below 2, input_dict itself is returned unchanged.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # Sorted keys give a deterministic stacking order; this presumes every
        # process passes the same set of keys (TODO confirm at call sites).
        keys = sorted(input_dict)
        stacked = torch.stack([input_dict[k] for k in keys], dim=0)
        dist.all_reduce(stacked)
        if average:
            stacked /= world_size
        reduced_dict = dict(zip(keys, stacked))
    return reduced_dict
| |
|
class NestedTensor(object):
    """Container pairing a (batched) tensor with an optional mask tensor."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        # Data tensor, stored as given.
        self.tensors = tensors
        # Optional mask tensor; may be None.
        self.mask = mask

    def to(self, device):
        """Return a new NestedTensor with the data (and mask, if present) moved to *device*."""
        moved_tensors = self.tensors.to(device)
        moved_mask = self.mask.to(device) if self.mask is not None else None
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return (self.tensors, self.mask)

    def __repr__(self):
        # Show only the data tensor, exactly as str() renders it.
        return str(self.tensors)
| |
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Batch a list of 3-D tensors into a zero-padded NestedTensor.

    Each tensor is copied into the leading corner of a zero tensor whose size
    is the per-axis maximum over the list. The mask is True on padded
    positions and False over the (last two) axes covered by an input tensor.

    Args:
        tensor_list: non-empty list of 3-D tensors (e.g. C x H x W); dtype
            and device are taken from the first element.

    Returns:
        NestedTensor with data of shape (B, C_max, H_max, W_max) and a bool
        mask of shape (B, H_max, W_max).

    Raises:
        ValueError: if tensor_list is empty or its first tensor is not 3-D.
    """
    if len(tensor_list) == 0:
        raise ValueError("tensor_list must contain at least one tensor")
    if tensor_list[0].ndim != 3:
        raise ValueError("not supported")
    if torchvision._is_tracing():
        # While tracing, switch to the variant that derives sizes with tensor
        # ops (named for ONNX export) instead of Python ints.
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    b, _, h, w = batch_shape
    dtype = tensor_list[0].dtype
    device = tensor_list[0].device
    tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
    # Start fully masked (True = padding) and clear the region each image covers.
    mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
    for img, pad_img, m in zip(tensor_list, tensor, mask):
        pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        m[: img.shape[1], : img.shape[2]] = False
    return NestedTensor(tensor, mask)
| |
|
| | |
| | |
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """Pad a list of 3-D tensors to a common size using only tensor ops.

    Per-axis maxima come from torch.stack / torch.max rather than Python
    ``max``. NOTE(review): torch.stack over ``img.shape[i]`` only works when
    shape entries are tensors (i.e. under tracing); in eager mode they are
    ints. Images are zero-padded with F.pad; masks are padded with 1 and
    returned as bool (True = padding).
    """
    num_dims = tensor_list[0].dim()
    max_size = tuple(
        torch.max(
            torch.stack([img.shape[axis] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        for axis in range(num_dims)
    )

    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        # Amount of padding needed on each axis to reach max_size.
        padding = [target - actual for target, actual in zip(max_size, tuple(img.shape))]
        # F.pad takes (last-axis before/after, ..., first-axis before/after).
        padded_imgs.append(
            torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        )

        zero_mask = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(
            zero_mask, (0, padding[2], 0, padding[1]), "constant", 1
        )
        padded_masks.append(padded_mask.to(torch.bool))

    return NestedTensor(torch.stack(padded_imgs), mask=torch.stack(padded_masks))
| |
|
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialized."""
    # Short-circuit: is_initialized() is only consulted when the backend exists.
    return dist.is_available() and dist.is_initialized()
| |
|
def load_parallal_model(model, state_dict_):
    """Load a checkpoint state dict into *model*, tolerating mismatches.

    A leading ``module.`` prefix (as added by wrappers such as
    nn.DataParallel) is stripped from keys. Parameters whose shape differs
    from the model's are skipped, keys the model does not have are reported
    and ignored, and model parameters missing from the checkpoint keep their
    current values. Loading uses ``strict=False``.

    Args:
        model: torch.nn.Module to load into (modified in place).
        state_dict_: mapping of parameter name -> tensor from a checkpoint.

    Returns:
        The same *model* object.
    """
    prefix = 'module.'
    state_dict = OrderedDict()
    for key in state_dict_:
        # Match the full "module." prefix so names such as "module_list.*"
        # or "moduleA.*" are not truncated.
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict_[key]
        else:
            state_dict[key] = state_dict_[key]

    model_state_dict = model.state_dict()
    for key in state_dict:
        if key not in model_state_dict:
            print('Drop parameter {}.'.format(key))
        elif state_dict[key].shape != model_state_dict[key].shape:
            print('Skip loading parameter {}, required shape{}, loaded shape{}.'.format(
                key, model_state_dict[key].shape, state_dict[key].shape))
            # Keep the model's own value for mismatched shapes.
            state_dict[key] = model_state_dict[key]
    for key in model_state_dict:
        if key not in state_dict:
            print('No param {}.'.format(key))
            state_dict[key] = model_state_dict[key]
    model.load_state_dict(state_dict, strict=False)

    return model
| |
|
class ADEVisualize(object):
    """Colorize integer label maps and blend them over images.

    The color table is read from a MATLAB file (``colors`` variable) and
    class names from a CSV file. The default file names suggest the
    150-class ADE20K palette (TODO confirm).
    """

    def __init__(self, color_path='dataset/color150.mat',
                 info_path='dataset/object150_info.csv'):
        """Load the color table and the class-name table.

        Args:
            color_path: .mat file containing a ``colors`` array indexed by label.
            info_path: CSV file with one header row; column 0 is the integer
                class index and column 5 holds ';'-separated names (the first
                name is kept).
        """
        self.colors = loadmat(color_path)['colors']
        self.names = {}
        # newline='' is the open mode recommended by the csv module docs.
        with open(info_path, newline='') as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            for row in reader:
                self.names[int(row[0])] = row[5].split(";")[0]

    def unique(self, ar, return_index=False, return_inverse=False, return_counts=False):
        """Return the sorted unique elements of a flattened array, like ``np.unique``.

        Args:
            ar: array-like; it is flattened (copied) before processing.
            return_index: also return the index of the first occurrence of
                each unique value (a stable mergesort is used for this).
            return_inverse: also return indices that rebuild the flattened
                input from the unique values.
            return_counts: also return how many times each value occurs.

        Returns:
            The unique values alone, or a tuple of the unique values followed
            by the requested optional arrays in the order index, inverse,
            counts.
        """
        ar = np.asanyarray(ar).flatten()

        optional_indices = return_index or return_inverse
        optional_returns = optional_indices or return_counts

        if ar.size == 0:
            if not optional_returns:
                return ar
            ret = (ar,)
            # Index-type outputs use intp, matching the non-empty path below
            # and np.unique (np.bool is unavailable on several NumPy releases).
            if return_index:
                ret += (np.empty(0, np.intp),)
            if return_inverse:
                ret += (np.empty(0, np.intp),)
            if return_counts:
                ret += (np.empty(0, np.intp),)
            return ret

        if optional_indices:
            # mergesort is stable, so perm[flag] points at first occurrences.
            perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
            aux = ar[perm]
        else:
            ar.sort()  # safe: ar is a private copy produced by flatten()
            aux = ar
        # flag[i] is True where aux[i] starts a new run of equal values.
        flag = np.concatenate(([True], aux[1:] != aux[:-1]))

        if not optional_returns:
            return aux[flag]
        ret = (aux[flag],)
        if return_index:
            ret += (perm[flag],)
        if return_inverse:
            iflag = np.cumsum(flag) - 1
            inv_idx = np.empty(ar.shape, dtype=np.intp)
            inv_idx[perm] = iflag
            ret += (inv_idx,)
        if return_counts:
            idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
            ret += (np.diff(idx),)
        return ret

    def colorEncode(self, labelmap, colors, mode='RGB'):
        """Map a 2-D label map to an (H, W, 3) uint8 color image.

        Args:
            labelmap: 2-D array of labels; cast to int, negative labels stay black.
            colors: array indexed by label, one 3-channel color per row.
            mode: 'BGR' reverses the channel order of the result; any other
                value returns it unchanged.

        Returns:
            np.uint8 array of shape (labelmap.shape[0], labelmap.shape[1], 3).
        """
        labelmap = labelmap.astype('int')
        labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
                                dtype=np.uint8)
        for label in self.unique(labelmap):
            if label < 0:
                continue
            # Label regions are disjoint, so assigning each region its color
            # equals summing masked, full-image color planes into a zero
            # image, without allocating an H x W x 3 tile per label.
            labelmap_rgb[labelmap == label] = colors[label]

        if mode == 'BGR':
            return labelmap_rgb[:, :, ::-1]
        return labelmap_rgb

    def show_result(self, img, pred, save_path=None):
        """Blend the colorized prediction over *img* with alpha 0.6.

        Args:
            img: image object supporting ``.convert('RGBA')`` (presumably a
                PIL image); its size must match ``pred`` for Image.blend.
            pred: 2-D label map; cast to int32 before colorizing.
            save_path: if given, the blend is saved there; otherwise it is
                drawn with ``plt.imshow``.
        """
        pred = np.int32(pred)
        pred_color = self.colorEncode(pred, self.colors)
        pil_img = img.convert('RGBA')
        pred_color = Image.fromarray(pred_color).convert('RGBA')
        im_vis = Image.blend(pil_img, pred_color, 0.6)
        if save_path is not None:
            im_vis.save(save_path)
        else:
            plt.imshow(im_vis)