| import numpy as np | |
| import torchvision.transforms as T | |
| import torchvision.transforms.functional as TF | |
| import torch | |
| from scipy.ndimage import gaussian_filter, map_coordinates | |
class AdvancedAugmentations:
    """Paired random augmentation for an image and its heatmap.

    Geometric steps (flips, rotation, crop, resize, affine, elastic warp) are
    applied with identical parameters to both inputs so they stay spatially
    aligned; photometric steps (brightness, contrast, Gaussian noise) touch
    only the image. Each optional step fires independently when
    ``np.random.rand() > 0.5``.

    Args:
        target_size: Size passed to ``TF.resize``; with the default
            ``(height, width)`` tuple every output has exactly that size.
    """

    def __init__(self, target_size=(1024, 1024)):
        self.target_size = target_size

    def __call__(self, image, heatmap):
        """Augment ``image`` and ``heatmap`` with the same random geometry.

        Args:
            image: Tensor or ndarray accepted by ``TF.to_pil_image``.
            heatmap: Tensor or ndarray accepted by ``TF.to_pil_image``. A
                single-channel floating-point heatmap is carried as float32
                (PIL mode ``"F"``) rather than being quantized to uint8.

        Returns:
            ``(image, heatmap)`` as float tensors shaped ``(C, H, W)`` and
            resized to ``self.target_size``.
        """
        image = TF.to_pil_image(image)
        heatmap = self._heatmap_to_pil(heatmap)

        if np.random.rand() > 0.5:
            image = TF.hflip(image)
            heatmap = TF.hflip(heatmap)
        if np.random.rand() > 0.5:
            image = TF.vflip(image)
            heatmap = TF.vflip(heatmap)
        if np.random.rand() > 0.5:
            angle = np.random.uniform(-45, 45)
            image = TF.rotate(image, angle, interpolation=TF.InterpolationMode.BILINEAR)
            heatmap = TF.rotate(heatmap, angle, interpolation=TF.InterpolationMode.BILINEAR)
        if np.random.rand() > 0.5:
            width, height = image.size  # PIL reports (width, height)
            # Square crop covering 80-100% of the shorter side; clamp to >= 1 px
            # so tiny inputs cannot yield an empty crop.
            crop_size = max(1, int(min(width, height) * np.random.uniform(0.8, 1.0)))
            i, j, h, w = T.RandomCrop.get_params(image, (crop_size, crop_size))
            image = TF.crop(image, i, j, h, w)
            heatmap = TF.crop(heatmap, i, j, h, w)

        # Resize unconditionally so every sample leaves with the same spatial
        # size whether or not the crop fired (same-size PIL resize is a copy).
        image = TF.resize(image, self.target_size, interpolation=TF.InterpolationMode.BILINEAR)
        heatmap = TF.resize(heatmap, self.target_size, interpolation=TF.InterpolationMode.BILINEAR)

        if np.random.rand() > 0.5:
            image, heatmap = self.random_affine(image, heatmap)

        # Both are still PIL images here; everything below works on tensors.
        image = TF.to_tensor(image)
        heatmap = TF.to_tensor(heatmap)

        if np.random.rand() > 0.5:
            brightness_factor = np.random.uniform(0.8, 1.2)
            image = TF.adjust_brightness(image, brightness_factor)
        if np.random.rand() > 0.5:
            contrast_factor = np.random.uniform(0.8, 1.2)
            image = TF.adjust_contrast(image, contrast_factor)
        if np.random.rand() > 0.5:
            noise_level = np.random.uniform(0.01, 0.05)
            noise = torch.randn_like(image) * noise_level
            image = torch.clamp(image + noise, 0, 1)
        if np.random.rand() > 0.5:
            image, heatmap = self.elastic_transform(image, heatmap)
        return image, heatmap

    @staticmethod
    def _heatmap_to_pil(heatmap):
        """Convert ``heatmap`` to PIL, keeping single-channel float data lossless.

        ``TF.to_pil_image`` multiplies a floating-point *tensor* by 255 and casts
        it to uint8, which quantizes the heatmap and wraps values above 1. A
        2-D float32 ndarray instead becomes PIL mode ``"F"`` with no scaling, so
        single-channel floating-point inputs are routed that way. Any other
        input falls back to ``TF.to_pil_image`` exactly as before.
        """
        if isinstance(heatmap, torch.Tensor) and heatmap.is_floating_point():
            arr = heatmap.detach().cpu().numpy()
            if arr.ndim == 3 and arr.shape[0] == 1:  # (1, H, W) -> (H, W)
                arr = arr[0]
        elif isinstance(heatmap, np.ndarray) and np.issubdtype(heatmap.dtype, np.floating):
            arr = heatmap
            if arr.ndim == 3 and arr.shape[2] == 1:  # (H, W, 1) -> (H, W)
                arr = arr[..., 0]
        else:
            return TF.to_pil_image(heatmap)
        if arr.ndim != 2:
            # Multi-channel float data: keep the original conversion path.
            return TF.to_pil_image(heatmap)
        return TF.to_pil_image(np.ascontiguousarray(arr, dtype=np.float32))

    def random_affine(self, image, heatmap):
        """Apply one random affine draw identically to ``image`` and ``heatmap``.

        Parameters come from ``T.RandomAffine.get_params`` with rotation in
        [-10, 10] degrees, translation up to 5% of each dimension, scale in
        [0.95, 1.05] and shear range [-5, 5] degrees. ``image.size`` is used as
        the image size, so PIL images (size = (width, height)) are expected.
        """
        degrees = [-10.0, 10.0]
        translate = [0.05, 0.05]
        scale = [0.95, 1.05]
        shear = [-5.0, 5.0]
        params = T.RandomAffine.get_params(degrees, translate, scale, shear, image.size)
        angle, translate, scale, shear = params
        # get_params returns tuples; TF.affine is given lists.
        translate = list(translate)
        shear = list(shear)
        image = TF.affine(image, angle, translate, scale, shear, interpolation=TF.InterpolationMode.BILINEAR)
        heatmap = TF.affine(heatmap, angle, translate, scale, shear, interpolation=TF.InterpolationMode.BILINEAR)
        return image, heatmap

    def elastic_transform(self, image, heatmap, alpha=50, sigma=4):
        """Warp ``image`` and ``heatmap`` with one shared random elastic field.

        A uniform [-1, 1] noise field per axis is smoothed with a Gaussian of
        std ``sigma`` and scaled by ``alpha`` (pixels); every channel of both
        inputs is then resampled at the displaced coordinates with linear
        interpolation (out-of-range samples become 0).

        Args:
            image: ``(C, H, W)`` tensor, or an array-like ``(H, W[, C])``.
            heatmap: Same layout convention as ``image``.
            alpha: Displacement magnitude.
            sigma: Smoothness of the displacement field.

        Returns:
            ``(image, heatmap)`` as float32 tensors shaped ``(C, H, W)``.
        """
        if isinstance(image, torch.Tensor):
            image_np = image.permute(1, 2, 0).numpy()
            heatmap_np = heatmap.permute(1, 2, 0).numpy()
        else:
            image_np = np.asarray(image)
            heatmap_np = np.asarray(heatmap)
        if image_np.ndim == 2:
            image_np = image_np[:, :, np.newaxis]
        if heatmap_np.ndim == 2:
            heatmap_np = heatmap_np[:, :, np.newaxis]

        shape = image_np.shape[:2]
        dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
        dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        # (row, col) sample coordinates shared by every channel of both inputs.
        indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1))

        def _warp(arr):
            # Resample each channel at the shared displaced coordinates.
            out = np.zeros_like(arr)
            for c in range(arr.shape[2]):
                out[..., c] = map_coordinates(arr[..., c], indices, order=1).reshape(shape)
            return out

        image_transformed = _warp(image_np)
        heatmap_transformed = _warp(heatmap_np)
        return (
            torch.from_numpy(image_transformed).float().permute(2, 0, 1),
            torch.from_numpy(heatmap_transformed).float().permute(2, 0, 1),
        )