"""
Dataset and data loading for S2F training.

Expects folder structure: each subfolder has BF_001.tif (bright field),
*_gray.jpg (heatmap), and optionally .txt (cell_area, sum_force).
"""
import os
from concurrent.futures import ThreadPoolExecutor

import cv2
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

from utils import config


def blur_force_map(force_map, ksize=25, sigma=10):
    """Gaussian-blur channel 0 of every force map in a batch.

    Args:
        force_map: Tensor with 3 dims (treated as a batch of one) or 4 dims
            (batch first, channel second). Only channel 0 is read.
        ksize: Gaussian kernel size; an even value is bumped to the next odd
            one because the kernel is passed to ``cv2.GaussianBlur``.
        sigma: Value passed as ``sigmaX`` to ``cv2.GaussianBlur``.

    Returns:
        float32 tensor on the input's device with one blurred 2-D map per
        batch entry stacked along dim 0. Note the channel axis is NOT
        restored: a (B, C, H, W) input yields (B, H, W).
    """
    if ksize % 2 == 0:
        ksize += 1
    if force_map.dim() == 3:
        force_map = force_map.unsqueeze(0)

    device = force_map.device
    # OpenCV works on host numpy arrays, so round-trip through the CPU.
    force_map = force_map.cpu()
    blurred_maps = [
        cv2.GaussianBlur(
            force_map[i, 0].numpy().astype(np.float32),
            (ksize, ksize),
            sigmaX=sigma,
        )
        for i in range(force_map.size(0))
    ]
    return torch.from_numpy(np.stack(blurred_maps)).to(device)


class ImageDataset(Dataset):
    """Dataset over pre-loaded (bright field, heatmap, numbers[, metadata]) tuples.

    Args:
        image_pairs: Sequence of samples. Each sample is
            ``(bf_image, hm_image, numbers)`` or, when ``return_metadata`` is
            True, ``(bf_image, hm_image, numbers, metadata)``. The images are
            2-D numpy arrays; ``numbers`` is either a ``(cell_area, sum_force)``
            tuple or a bare ``sum_force`` value.
        transform: Optional callable ``(image, heatmap) -> (image, heatmap)``
            applied to the (1, H, W) tensors.
        channel_first: If False, outputs are permuted to channels-last.
        blur_heatmap: If True, the heatmap is passed through ``blur_force_map``.
        threshold: Heatmap values ``<= threshold`` are set to 0.
        return_metadata: If True, the sample's metadata is returned as a fifth
            element.
    """

    def __init__(self, image_pairs, transform=None, channel_first=True,
                 blur_heatmap=False, threshold=0.0, return_metadata=False):
        self.image_pairs = image_pairs
        self.transform = transform
        self.channel_first = channel_first
        self.blur_heatmap = blur_heatmap
        self.threshold = threshold
        self.return_metadata = return_metadata

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Return ``(image, heatmap, cell_area, sum_force[, metadata])`` for ``idx``."""
        sample = self.image_pairs[idx]
        if self.return_metadata:
            bf_image, hm_image, numbers, metadata = sample
        else:
            bf_image, hm_image, numbers = sample
            metadata = None

        if isinstance(numbers, tuple):
            cell_area, sum_force = numbers
        else:
            # A bare number is interpreted as the force sum only.
            cell_area, sum_force = 0, numbers

        image = torch.from_numpy(bf_image).float().unsqueeze(0)
        heatmap = torch.from_numpy(hm_image).float().unsqueeze(0)
        if self.transform:
            image, heatmap = self.transform(image, heatmap)

        cell_area = torch.tensor(cell_area, dtype=torch.float32)
        sum_force = torch.tensor(sum_force, dtype=torch.float32)

        # Out-of-place on purpose: torch.from_numpy shares memory with the
        # cached numpy array, so an in-place masked assignment would silently
        # rewrite the stored heatmap in ``image_pairs``.
        heatmap = heatmap.masked_fill(heatmap <= self.threshold, 0)
        if self.blur_heatmap:
            heatmap = blur_force_map(heatmap)

        if not self.channel_first:
            # (C, H, W) -> (H, W, C). permute(2, 1, 0) would also swap H and W,
            # i.e. transpose the image.
            image = image.permute(1, 2, 0)
            heatmap = heatmap.permute(1, 2, 0)

        if self.return_metadata:
            return image, heatmap, cell_area, sum_force, metadata
        return image, heatmap, cell_area, sum_force


def load_image(filepath, target_size):
    """Read an image as grayscale, resize it and scale it to [0, 1] float32.

    Args:
        filepath: Path of the image file.
        target_size: Int (used for both dimensions) or a size passed straight
            to ``cv2.resize`` (OpenCV's ``dsize`` is ``(width, height)``).

    Returns:
        2-D float32 numpy array with values divided by 255.

    Raises:
        OSError: If OpenCV cannot read the file.
    """
    img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError(f"Could not read image (missing or undecodable): {filepath}")
    if isinstance(target_size, int):
        target_size = (target_size, target_size)
    img = cv2.resize(img, target_size)
    return (img / 255.0).astype(np.float32)


def load_text_data(filepath):
    """Parse the cell-area and force-sum values from a sample's text file.

    The first two non-empty lines must each look like ``<label>: <number>``.
    The first value is multiplied by ``config.SCALE_FACTOR_AREA`` and the
    second by ``config.SCALE_FACTOR_FORCE``.

    Returns:
        Tuple ``(cell_area, sum_force)`` of scaled floats.
    """
    with open(filepath, 'r') as f:
        lines = [line.strip() for line in f if line.strip()]
    cell_area_diff = float(lines[0].split(":")[1].strip()) * config.SCALE_FACTOR_AREA
    sum_force_diff = float(lines[1].split(":")[1].strip()) * config.SCALE_FACTOR_FORCE
    return (cell_area_diff, sum_force_diff)


def load_images_from_subfolders(root_folder, target_size, load_numerical_data=True,
                                load_force_sum=False, return_metadata=False,
                                substrate=None):
    """Load every complete sample found in the subfolders of ``root_folder``.

    A subfolder is a sample when it contains a file ending in ``BF_001.tif``
    and one ending in ``_gray.jpg``; with ``load_numerical_data`` it must also
    contain a ``.txt`` file. Incomplete subfolders are skipped.

    Args:
        root_folder: Directory whose immediate subdirectories are scanned.
        target_size: Size forwarded to ``load_image``.
        load_numerical_data: Read ``(cell_area, sum_force)`` from the text file.
        load_force_sum: When ``load_numerical_data`` is False, use
            ``(0, sum(heatmap) * config.SCALE_FACTOR_FORCE)`` as the numbers.
        return_metadata: Append a metadata dict (``folder_name``, ``substrate``,
            ``root_folder``) to each sample.
        substrate: Substrate name stored in the metadata; required when
            ``return_metadata`` is True.

    Returns:
        List of ``(bf_image, hm_image, numbers)`` tuples, or
        ``(bf_image, hm_image, numbers, metadata)`` when ``return_metadata``.
        ``numbers`` is ``(0, 0)`` when neither numeric option is enabled.

    Raises:
        ValueError: If ``return_metadata`` is True and ``substrate`` is None.
    """
    if return_metadata and substrate is None:
        from utils.substrate_settings import list_substrates
        raise ValueError("substrate must be passed when return_metadata=True. Options: "
                         + ", ".join(list_substrates()))

    paired_images = []
    numerical_data = []
    metadata = []

    # os.listdir order is filesystem-dependent; sorting keeps the sample order
    # (and any seeded split made from it) reproducible.
    for subfolder in sorted(os.listdir(root_folder)):
        subfolder_path = os.path.join(root_folder, subfolder)
        if not os.path.isdir(subfolder_path):
            continue

        bf_image_path = hm_image_path = txt_file_path = None
        for filename in sorted(os.listdir(subfolder_path)):
            if filename.endswith("BF_001.tif"):
                bf_image_path = os.path.join(subfolder_path, filename)
            elif filename.endswith("_gray.jpg"):
                hm_image_path = os.path.join(subfolder_path, filename)
            elif filename.endswith(".txt"):
                txt_file_path = os.path.join(subfolder_path, filename)

        if not (bf_image_path and hm_image_path):
            continue
        if load_numerical_data:
            if not txt_file_path:
                continue
            numerical_data.append(load_text_data(txt_file_path))

        paired_images.append((bf_image_path, hm_image_path))
        if return_metadata:
            # Recorded only for accepted samples so that metadata stays aligned
            # with the image lists when incomplete subfolders are skipped.
            metadata.append({'folder_name': subfolder, 'substrate': substrate,
                             'root_folder': root_folder})

    def _load(path):
        return load_image(path, target_size)

    with ThreadPoolExecutor() as executor:
        bf_loaded = list(executor.map(_load, [bf for bf, _ in paired_images]))
        hm_loaded = list(executor.map(_load, [hm for _, hm in paired_images]))

    if load_force_sum and not load_numerical_data:
        # Reuse the heatmaps loaded above instead of reading each file twice.
        numerical_data = [(0, float(np.sum(hm)) * config.SCALE_FACTOR_FORCE)
                          for hm in hm_loaded]
    elif not numerical_data:
        numerical_data = [(0, 0)] * len(bf_loaded)

    if return_metadata:
        return list(zip(bf_loaded, hm_loaded, numerical_data, metadata))
    return list(zip(bf_loaded, hm_loaded, numerical_data))


def prepare_data(input_folder, batch_size=8, target_size=(1024, 1024), split_size=0.2,
                 use_augmentations=True, train_test_sep_folder=True, channel_first=True,
                 load_numerical_data=False, load_force_sum=False, blur_heatmap=False,
                 threshold=0.0, return_metadata=False, substrate=None):
    """Build the training and validation data loaders.

    Args:
        input_folder: Dataset root. With ``train_test_sep_folder`` it must
            contain ``train`` and ``test`` subfolders; otherwise its samples
            are split with ``train_test_split(test_size=split_size,
            random_state=42)``.
        batch_size: Batch size of both loaders.
        target_size: Size forwarded to the image loader and the augmentations.
        split_size: Validation fraction when no separate folders are used.
        use_augmentations: Apply ``AdvancedAugmentations`` to the training set.
        train_test_sep_folder: Whether ``train``/``test`` subfolders exist.
        channel_first, blur_heatmap, threshold, return_metadata: Forwarded to
            ``ImageDataset``.
        load_numerical_data, load_force_sum, substrate: Forwarded to
            ``load_images_from_subfolders``.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.

    Raises:
        ValueError: If both numeric options are enabled, or the ``train`` /
            ``test`` folders are missing.
    """
    if load_numerical_data and load_force_sum:
        raise ValueError("load_numerical_data and load_force_sum cannot be True at the same time")

    load_kwargs = dict(target_size=target_size,
                       load_numerical_data=load_numerical_data,
                       load_force_sum=load_force_sum,
                       return_metadata=return_metadata,
                       substrate=substrate)

    if train_test_sep_folder:
        train_folder = os.path.join(input_folder, 'train')
        test_folder = os.path.join(input_folder, 'test')
        if not (os.path.exists(train_folder) and os.path.exists(test_folder)):
            raise ValueError(f"train/test folders not found in {input_folder}")
        train_pairs = load_images_from_subfolders(train_folder, **load_kwargs)
        val_pairs = load_images_from_subfolders(test_folder, **load_kwargs)
    else:
        image_pairs = load_images_from_subfolders(input_folder, **load_kwargs)
        train_pairs, val_pairs = train_test_split(image_pairs, test_size=split_size,
                                                  random_state=42)

    train_transform = None
    if use_augmentations:
        from .augmentations import AdvancedAugmentations
        train_transform = AdvancedAugmentations(target_size)

    # normpath first so that a trailing separator does not produce an empty name.
    dataset_name = os.path.basename(os.path.normpath(input_folder))

    train_dataset = ImageDataset(train_pairs, transform=train_transform,
                                 channel_first=channel_first, blur_heatmap=blur_heatmap,
                                 threshold=threshold, return_metadata=return_metadata)
    train_dataset.name = dataset_name
    val_dataset = ImageDataset(val_pairs, channel_first=channel_first,
                               blur_heatmap=blur_heatmap, threshold=threshold,
                               return_metadata=return_metadata)
    val_dataset.name = dataset_name

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


def load_folder_data(folder_path, substrate=None, img_size=1024, blur_heatmap=False,
                     batch_size=2, threshold=0.0, return_metadata=False):
    """Return a non-shuffling loader over all samples in ``folder_path``.

    Samples are loaded without numerical data (numbers are ``(0, 0)``), without
    augmentations and in channel-first layout.
    """
    val_pairs = load_images_from_subfolders(folder_path, target_size=img_size,
                                            load_numerical_data=False,
                                            load_force_sum=False,
                                            return_metadata=return_metadata,
                                            substrate=substrate)
    val_dataset = ImageDataset(val_pairs, channel_first=True, blur_heatmap=blur_heatmap,
                               threshold=threshold, return_metadata=return_metadata)
    # normpath first so that a trailing separator does not produce an empty name.
    val_dataset.name = os.path.basename(os.path.normpath(folder_path))
    return DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


def collect_image_paths(folder_path, exts=None):
    """Recursively collect image files under ``folder_path``.

    Args:
        folder_path: Directory to walk.
        exts: Collection of lower-case extensions (with leading dot) to accept;
            defaults to .tif, .tiff, .jpg, .jpeg and .png.

    Returns:
        Sorted list of matching file paths.
    """
    if exts is None:
        exts = {".tif", ".tiff", ".jpg", ".jpeg", ".png"}
    paths = [
        os.path.join(root, f)
        for root, _, files in os.walk(os.path.normpath(folder_path))
        for f in files
        if os.path.splitext(f)[1].lower() in exts
    ]
    return sorted(paths)


class BrightfieldOnlyDataset(Dataset):
    """Dataset of brightfield images only (no labels), for inference."""

    def __init__(self, folder_path, target_size=1024):
        self.paths = collect_image_paths(folder_path)
        if isinstance(target_size, int):
            target_size = (target_size, target_size)
        self.target_size = target_size

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        """Return image ``i`` as a float tensor with a leading channel axis."""
        x = load_image(self.paths[i], self.target_size)
        return torch.from_numpy(x).float().unsqueeze(0)


def load_brightfield_loader(folder_path, img_size=1024, batch_size=2):
    """Return a non-shuffling loader over the brightfield images in ``folder_path``."""
    ds = BrightfieldOnlyDataset(folder_path, target_size=img_size)
    return DataLoader(ds, batch_size=batch_size, shuffle=False)