File size: 4,470 Bytes
2b9ff22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torch
from scipy.ndimage import gaussian_filter, map_coordinates


class AdvancedAugmentations:
    """Jointly augment an image and its heatmap with random transforms.

    Geometric steps (flips, rotation, crop, affine, elastic warp) use the
    same sampled parameters for both inputs so the heatmap stays aligned
    with the image. Photometric steps (brightness, contrast, additive
    Gaussian noise) are applied to the image only.

    Every step fires independently when ``np.random.rand() > 0.5``.  The
    decisions and most magnitudes come from NumPy's global generator, while
    the noise uses ``torch.randn_like`` (and the torchvision ``get_params``
    helpers presumably draw from torch as well), so seed both libraries
    when reproducibility matters.
    """

    def __init__(self, target_size=(1024, 1024)):
        """Remember the output size.

        Args:
            target_size: Size handed to ``TF.resize`` for both image and
                heatmap (torchvision reads a pair as ``(height, width)``).
        """
        self.target_size = target_size

    def __call__(self, image, heatmap):
        """Return a randomly augmented ``(image, heatmap)`` pair.

        Args:
            image: Anything ``TF.to_pil_image`` accepts.
            heatmap: Anything ``TF.to_pil_image`` accepts.  It is assumed to
                have the same spatial size as ``image`` because the crop box
                and the elastic displacement field are shared.

        Returns:
            Tuple ``(image, heatmap)`` of channel-first tensors resized to
            ``self.target_size``.
        """
        bilinear = TF.InterpolationMode.BILINEAR

        image = TF.to_pil_image(image)
        heatmap = TF.to_pil_image(heatmap)

        # --- geometric steps on PIL images (shared parameters) ---
        if np.random.rand() > 0.5:
            image = TF.hflip(image)
            heatmap = TF.hflip(heatmap)
        if np.random.rand() > 0.5:
            image = TF.vflip(image)
            heatmap = TF.vflip(heatmap)

        if np.random.rand() > 0.5:
            angle = np.random.uniform(-45, 45)
            image = TF.rotate(image, angle, interpolation=bilinear)
            heatmap = TF.rotate(heatmap, angle, interpolation=bilinear)

        if np.random.rand() > 0.5:
            width, height = image.size
            # Square crop covering 80-100% of the shorter side.
            crop_size = int(min(width, height) * np.random.uniform(0.8, 1.0))
            top, left, crop_h, crop_w = T.RandomCrop.get_params(image, (crop_size, crop_size))
            image = TF.crop(image, top, left, crop_h, crop_w)
            heatmap = TF.crop(heatmap, top, left, crop_h, crop_w)

        # Resize on every call, not only after a crop.  Resizing used to be
        # tied to the crop branch, so uncropped samples kept their input size
        # and the output only matched ``target_size`` about half of the time.
        # A same-size resize leaves the pixels untouched.
        image = TF.resize(image, self.target_size, interpolation=bilinear)
        heatmap = TF.resize(heatmap, self.target_size, interpolation=bilinear)

        if np.random.rand() > 0.5:
            image, heatmap = self.random_affine(image, heatmap)

        # --- switch to tensors for the photometric steps ---
        if not isinstance(image, torch.Tensor):
            image = TF.to_tensor(image)
        if not isinstance(heatmap, torch.Tensor):
            heatmap = TF.to_tensor(heatmap)

        # Photometric changes must not alter the target, so only the image
        # is touched from here until the elastic warp.
        if np.random.rand() > 0.5:
            brightness_factor = np.random.uniform(0.8, 1.2)
            image = TF.adjust_brightness(image, brightness_factor)
        if np.random.rand() > 0.5:
            contrast_factor = np.random.uniform(0.8, 1.2)
            image = TF.adjust_contrast(image, contrast_factor)

        if np.random.rand() > 0.5:
            noise_level = np.random.uniform(0.01, 0.05)
            noise = torch.randn_like(image) * noise_level
            image = torch.clamp(image + noise, 0, 1)

        if np.random.rand() > 0.5:
            image, heatmap = self.elastic_transform(image, heatmap)

        return image, heatmap

    def random_affine(self, image, heatmap):
        """Apply one randomly sampled affine transform to both inputs.

        Parameters are sampled once via ``T.RandomAffine.get_params`` from
        fixed ranges: rotation in [-10, 10] degrees, translation up to 5% of
        each side, scale in [0.95, 1.05] and shear in [-5, 5] degrees.

        Args:
            image: Image exposing ``.size`` (a PIL image in ``__call__``).
            heatmap: Heatmap to transform with the same parameters.

        Returns:
            Tuple ``(image, heatmap)`` after the affine transform.
        """
        degrees = [-10.0, 10.0]
        translate = [0.05, 0.05]
        scale = [0.95, 1.05]
        shear = [-5.0, 5.0]
        angle, shift, zoom, shear_xy = T.RandomAffine.get_params(
            degrees, translate, scale, shear, image.size
        )
        # ``TF.affine`` is given lists, matching what the sampled tuples
        # are converted to before the call.
        shift = list(shift)
        shear_xy = list(shear_xy)

        bilinear = TF.InterpolationMode.BILINEAR
        image = TF.affine(image, angle, shift, zoom, shear_xy, interpolation=bilinear)
        heatmap = TF.affine(heatmap, angle, shift, zoom, shear_xy, interpolation=bilinear)
        return image, heatmap

    def elastic_transform(self, image, heatmap, alpha=50, sigma=4):
        """Warp image and heatmap with one shared random displacement field.

        Uniform noise in [-1, 1) is smoothed with a Gaussian of standard
        deviation ``sigma`` and multiplied by ``alpha`` to obtain per-pixel
        x/y offsets; both inputs are then resampled at the displaced
        positions with linear interpolation (``order=1``).

        Args:
            image: Channel-first tensor ``(C, H, W)``, or an array-like of
                shape ``(H, W)`` / ``(H, W, C)``.
            heatmap: Same layout as ``image`` with the same ``H`` and ``W``.
            alpha: Multiplier applied to the smoothed displacement field.
            sigma: Standard deviation of the Gaussian smoothing.

        Returns:
            Tuple ``(image, heatmap)`` of float tensors in channel-first
            layout.  Non-tensor inputs keep their own value range (the
            arrays are warped in their original dtype and only then cast).

        Raises:
            ValueError: If image and heatmap differ in height or width.
        """
        if isinstance(image, torch.Tensor):
            image_np = image.permute(1, 2, 0).numpy()
            heatmap_np = heatmap.permute(1, 2, 0).numpy()
        else:
            image_np = np.asarray(image)
            heatmap_np = np.asarray(heatmap)
            if image_np.ndim == 2:
                image_np = image_np[:, :, np.newaxis]
            if heatmap_np.ndim == 2:
                heatmap_np = heatmap_np[:, :, np.newaxis]

        shape = image_np.shape[:2]
        if heatmap_np.shape[:2] != shape:
            # Fail early with a readable message instead of an opaque
            # broadcasting error from the assignment further down.
            raise ValueError(
                f"image and heatmap must share spatial size, got {shape} "
                f"and {heatmap_np.shape[:2]}"
            )

        # dx is drawn before dy so the random stream is consumed in a fixed order.
        dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
        dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha

        # Sampling grid of shape (2, H, W): displaced row and column positions.
        rows, cols = np.indices(shape)
        coords = np.stack((rows + dy, cols + dx))

        image_transformed = self._warp_channels(image_np, coords)
        heatmap_transformed = self._warp_channels(heatmap_np, coords)

        return torch.from_numpy(image_transformed).float().permute(2, 0, 1), \
               torch.from_numpy(heatmap_transformed).float().permute(2, 0, 1)

    @staticmethod
    def _warp_channels(array, coords):
        """Resample every channel of an ``(H, W, C)`` array at ``coords``."""
        warped = np.zeros_like(array)
        for channel in range(array.shape[2]):
            warped[..., channel] = map_coordinates(array[..., channel], coords, order=1)
        return warped