File size: 10,290 Bytes
2b9ff22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f83fb5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""
Dataset and data loading for S2F training.
Expects folder structure: each subfolder has BF_001.tif (bright field), *_gray.jpg (heatmap), and optionally .txt (cell_area, sum_force).
"""
import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from concurrent.futures import ThreadPoolExecutor
import numpy as np

from utils import config


def blur_force_map(force_map, ksize=25, sigma=10):
    if ksize % 2 == 0:
        ksize += 1
    if force_map.dim() == 3:
        force_map = force_map.unsqueeze(0)
    device = force_map.device
    force_map = force_map.cpu()
    blurred_maps = []
    for i in range(force_map.size(0)):
        force_np = force_map[i, 0].numpy().astype(np.float32)
        blurred = cv2.GaussianBlur(force_np, (ksize, ksize), sigmaX=sigma)
        blurred_maps.append(blurred)
    return torch.from_numpy(np.stack(blurred_maps)).to(device)


class ImageDataset(Dataset):
    def __init__(self, image_pairs, transform=None, channel_first=True,
                 blur_heatmap=False, threshold=0.0, return_metadata=False):
        self.image_pairs = image_pairs
        self.transform = transform
        self.channel_first = channel_first
        self.blur_heatmap = blur_heatmap
        self.threshold = threshold
        self.return_metadata = return_metadata

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        if self.return_metadata:
            bf_image, hm_image, numbers, metadata = self.image_pairs[idx]
        else:
            bf_image, hm_image, numbers = self.image_pairs[idx]
        if isinstance(numbers, tuple):
            cell_area, sum_force = numbers
        else:
            cell_area = 0
            sum_force = numbers

        image = torch.from_numpy(bf_image).float().unsqueeze(0)
        heatmap = torch.from_numpy(hm_image).float().unsqueeze(0)
        if self.transform:
            image, heatmap = self.transform(image, heatmap)
        cell_area = torch.tensor(cell_area, dtype=torch.float32)
        sum_force = torch.tensor(sum_force, dtype=torch.float32)
        heatmap[heatmap <= self.threshold] = 0
        if self.blur_heatmap:
            heatmap = blur_force_map(heatmap)
        if not self.channel_first:
            image = image.permute(2, 1, 0)
            heatmap = heatmap.permute(2, 1, 0)
        if self.return_metadata:
            return image, heatmap, cell_area, sum_force, metadata
        return image, heatmap, cell_area, sum_force


def load_image(filepath, target_size):
    img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if isinstance(target_size, int):
        target_size = (target_size, target_size)
    img = cv2.resize(img, target_size)
    img = img / 255.0
    return img.astype(np.float32)


def load_text_data(filepath):
    with open(filepath, 'r') as f:
        lines = [line.strip() for line in f if line.strip()]
        cell_area_diff = float(lines[0].split(":")[1].strip()) * config.SCALE_FACTOR_AREA
        sum_force_diff = float(lines[1].split(":")[1].strip()) * config.SCALE_FACTOR_FORCE
        return (cell_area_diff, sum_force_diff)


def load_images_from_subfolders(root_folder, target_size, load_numerical_data=True,
                                load_force_sum=False, return_metadata=False, substrate=None):
    paired_images = []
    numerical_data = []
    metadata = []
    for subfolder in os.listdir(root_folder):
        subfolder_path = os.path.join(root_folder, subfolder)
        if not os.path.isdir(subfolder_path):
            continue
        bf_image_path = hm_image_path = txt_file_path = None
        for filename in os.listdir(subfolder_path):
            if filename.endswith("BF_001.tif"):
                bf_image_path = os.path.join(subfolder_path, filename)
            elif filename.endswith("_gray.jpg"):
                hm_image_path = os.path.join(subfolder_path, filename)
            elif filename.endswith(".txt"):
                txt_file_path = os.path.join(subfolder_path, filename)

        if return_metadata:
            if substrate is None:
                from utils.substrate_settings import list_substrates
                raise ValueError("substrate must be passed when return_metadata=True. Options: " +
                                 ", ".join(list_substrates()))
            metadata.append({'folder_name': subfolder, 'substrate': substrate, 'root_folder': root_folder})

        if load_numerical_data:
            if bf_image_path and hm_image_path and txt_file_path:
                paired_images.append((bf_image_path, hm_image_path))
                numerical_data.append(load_text_data(txt_file_path))
        elif load_force_sum:
            if bf_image_path and hm_image_path:
                paired_images.append((bf_image_path, hm_image_path))
                hm = load_image(hm_image_path, target_size)
                numerical_data.append((0, float(np.sum(hm)) * config.SCALE_FACTOR_FORCE))
        else:
            if bf_image_path and hm_image_path:
                paired_images.append((bf_image_path, hm_image_path))

    with ThreadPoolExecutor() as executor:
        bf_loaded = list(executor.map(lambda p: load_image(p[0], target_size), paired_images))
        hm_loaded = list(executor.map(lambda p: load_image(p[1], target_size), paired_images))
    if not numerical_data:
        numerical_data = [(0, 0)] * len(bf_loaded)
    if return_metadata:
        return list(zip(bf_loaded, hm_loaded, numerical_data, metadata))
    return list(zip(bf_loaded, hm_loaded, numerical_data))


def prepare_data(input_folder, batch_size=8, target_size=(1024, 1024), split_size=0.2,
                 use_augmentations=True, train_test_sep_folder=True, channel_first=True,
                 load_numerical_data=False, load_force_sum=False, blur_heatmap=False,
                 threshold=0.0, return_metadata=False, substrate=None):
    if load_numerical_data and load_force_sum:
        raise ValueError("load_numerical_data and load_force_sum cannot be True at the same time")

    if train_test_sep_folder:
        train_folder = os.path.join(input_folder, 'train')
        test_folder = os.path.join(input_folder, 'test')
        if not (os.path.exists(train_folder) and os.path.exists(test_folder)):
            raise ValueError(f"train/test folders not found in {input_folder}")
        train_pairs = load_images_from_subfolders(train_folder, target_size=target_size,
                                                  load_numerical_data=load_numerical_data,
                                                  load_force_sum=load_force_sum,
                                                  return_metadata=return_metadata, substrate=substrate)
        val_pairs = load_images_from_subfolders(test_folder, target_size=target_size,
                                                load_numerical_data=load_numerical_data,
                                                load_force_sum=load_force_sum,
                                                return_metadata=return_metadata, substrate=substrate)
    else:
        image_pairs = load_images_from_subfolders(input_folder, target_size=target_size,
                                                  load_numerical_data=load_numerical_data,
                                                  load_force_sum=load_force_sum,
                                                  return_metadata=return_metadata, substrate=substrate)
        train_pairs, val_pairs = train_test_split(image_pairs, test_size=split_size, random_state=42)

    train_transform = None
    if use_augmentations:
        from .augmentations import AdvancedAugmentations
        train_transform = AdvancedAugmentations(target_size)

    train_dataset = ImageDataset(train_pairs, transform=train_transform, channel_first=channel_first,
                                 blur_heatmap=blur_heatmap, threshold=threshold, return_metadata=return_metadata)
    train_dataset.name = os.path.basename(input_folder)
    val_dataset = ImageDataset(val_pairs, channel_first=channel_first,
                               blur_heatmap=blur_heatmap, threshold=threshold, return_metadata=return_metadata)
    val_dataset.name = os.path.basename(input_folder)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


def load_folder_data(folder_path, substrate=None, img_size=1024, blur_heatmap=False,
                     batch_size=2, threshold=0.0, return_metadata=False):
    val_pairs = load_images_from_subfolders(folder_path, target_size=img_size,
                                            load_numerical_data=False, load_force_sum=False,
                                            return_metadata=return_metadata, substrate=substrate)
    val_dataset = ImageDataset(val_pairs, channel_first=True, blur_heatmap=blur_heatmap,
                               threshold=threshold, return_metadata=return_metadata)
    val_dataset.name = os.path.basename(folder_path)
    return DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


def collect_image_paths(folder_path, exts=None):
    if exts is None:
        exts = {".tif", ".tiff", ".jpg", ".jpeg", ".png"}
    paths = []
    for root, _, files in os.walk(os.path.normpath(folder_path)):
        for f in files:
            if os.path.splitext(f)[1].lower() in exts:
                paths.append(os.path.join(root, f))
    return sorted(paths)


class BrightfieldOnlyDataset(Dataset):
    """Dataset of brightfield images only (no labels), for inference."""
    def __init__(self, folder_path, target_size=1024):
        self.paths = collect_image_paths(folder_path)
        self.target_size = (target_size, target_size) if isinstance(target_size, int) else target_size

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        x = load_image(self.paths[i], self.target_size)
        return torch.from_numpy(x).float().unsqueeze(0)


def load_brightfield_loader(folder_path, img_size=1024, batch_size=2):
    ds = BrightfieldOnlyDataset(folder_path, target_size=img_size)
    return DataLoader(ds, batch_size=batch_size, shuffle=False)