| import warnings |
| from dataclasses import dataclass, asdict |
| from typing import Any, Dict, Optional, Sequence, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torchvision.transforms.functional as F |
| from functools import partial |
| from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ |
| CenterCrop |
|
|
| from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD |
|
|
|
|
| @dataclass |
| class AugmentationCfg: |
| scale: Tuple[float, float] = (0.9, 1.0) |
| ratio: Optional[Tuple[float, float]] = None |
| color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None |
| interpolation: Optional[str] = None |
| re_prob: Optional[float] = None |
| re_count: Optional[int] = None |
| use_timm: bool = False |
|
|
|
|
| class ResizeMaxSize(nn.Module): |
|
|
| def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): |
| super().__init__() |
| if not isinstance(max_size, int): |
| raise TypeError(f"Size should be int. Got {type(max_size)}") |
| self.max_size = max_size |
| self.interpolation = interpolation |
| self.fn = min if fn == 'min' else min |
| self.fill = fill |
|
|
| def forward(self, img): |
| if isinstance(img, torch.Tensor): |
| height, width = img.shape[1:] |
| else: |
| width, height = img.size |
| scale = self.max_size / float(max(height, width)) |
| if scale != 1.0: |
| new_size = tuple(round(dim * scale) for dim in (height, width)) |
| img = F.resize(img, new_size, self.interpolation) |
| pad_h = self.max_size - new_size[0] |
| pad_w = self.max_size - new_size[1] |
| img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) |
| return img |
|
|
|
|
| def _convert_to_rgb_or_rgba(image): |
| if image.mode == 'RGBA': |
| return image |
| else: |
| return image.convert('RGB') |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| class MaskAwareNormalize(nn.Module): |
| def __init__(self, mean, std): |
| super().__init__() |
| self.normalize = Normalize(mean=mean, std=std) |
|
|
| def forward(self, tensor): |
| if tensor.shape[0] == 4: |
| return torch.cat([self.normalize(tensor[:3]), tensor[3:]], dim=0) |
| else: |
| return self.normalize(tensor) |
|
|
| def image_transform( |
| image_size: int, |
| is_train: bool, |
| mean: Optional[Tuple[float, ...]] = None, |
| std: Optional[Tuple[float, ...]] = None, |
| resize_longest_max: bool = False, |
| fill_color: int = 0, |
| aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, |
| ): |
| mean = mean or OPENAI_DATASET_MEAN |
| if not isinstance(mean, (list, tuple)): |
| mean = (mean,) * 3 |
|
|
| std = std or OPENAI_DATASET_STD |
| if not isinstance(std, (list, tuple)): |
| std = (std,) * 3 |
|
|
| if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: |
| |
| image_size = image_size[0] |
|
|
| if isinstance(aug_cfg, dict): |
| aug_cfg = AugmentationCfg(**aug_cfg) |
| else: |
| aug_cfg = aug_cfg or AugmentationCfg() |
| normalize = MaskAwareNormalize(mean=mean, std=std) |
| if is_train: |
| aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} |
| use_timm = aug_cfg_dict.pop('use_timm', False) |
| if use_timm: |
| assert False, "not tested for augmentation with mask" |
| from timm.data import create_transform |
| if isinstance(image_size, (tuple, list)): |
| assert len(image_size) >= 2 |
| input_size = (3,) + image_size[-2:] |
| else: |
| input_size = (3, image_size, image_size) |
| |
| aug_cfg_dict.setdefault('interpolation', 'random') |
| aug_cfg_dict.setdefault('color_jitter', None) |
| train_transform = create_transform( |
| input_size=input_size, |
| is_training=True, |
| hflip=0., |
| mean=mean, |
| std=std, |
| re_mode='pixel', |
| **aug_cfg_dict, |
| ) |
| else: |
| train_transform = Compose([ |
| _convert_to_rgb_or_rgba, |
| ToTensor(), |
| RandomResizedCrop( |
| image_size, |
| scale=aug_cfg_dict.pop('scale'), |
| interpolation=InterpolationMode.BICUBIC, |
| ), |
| normalize, |
| ]) |
| if aug_cfg_dict: |
| warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') |
| return train_transform |
| else: |
| transforms = [ |
| _convert_to_rgb_or_rgba, |
| ToTensor(), |
| ] |
| if resize_longest_max: |
| transforms.extend([ |
| ResizeMaxSize(image_size, fill=fill_color) |
| ]) |
| else: |
| transforms.extend([ |
| Resize(image_size, interpolation=InterpolationMode.BICUBIC), |
| CenterCrop(image_size), |
| ]) |
| transforms.extend([ |
| normalize, |
| ]) |
| return Compose(transforms) |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |