Shape2Force / data /cell_dataset.py
kaveh's picture
added one function as helper
f83fb5c
"""
Dataset and data loading for S2F training.
Expected folder structure: each subfolder contains a bright-field image whose name ends in BF_001.tif, a heatmap whose name ends in _gray.jpg, and optionally a .txt file holding cell_area and sum_force.
"""
import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from utils import config
def blur_force_map(force_map, ksize=25, sigma=10):
    """Gaussian-blur channel 0 of every sample in ``force_map`` using OpenCV.

    Args:
        force_map: torch tensor. A 3-D tensor is treated as one sample and
            gets a leading batch dimension; each sample is then read as
            ``force_map[i, 0]``, so only the first channel is blurred.
        ksize: Gaussian kernel size. An even value is bumped to the next
            odd value before it is handed to OpenCV.
        sigma: forwarded to ``cv2.GaussianBlur`` as ``sigmaX``.

    Returns:
        A float32 tensor on the same device as the input, holding the
        blurred ``force_map[i, 0]`` slices stacked along dim 0. The channel
        dimension is not re-inserted.
    """
    kernel = ksize + 1 if ksize % 2 == 0 else ksize
    batch = force_map.unsqueeze(0) if force_map.dim() == 3 else force_map

    # OpenCV works on host numpy arrays; remember the device to move back.
    target_device = batch.device
    batch = batch.cpu()

    smoothed = [
        cv2.GaussianBlur(
            batch[n, 0].numpy().astype(np.float32),
            (kernel, kernel),
            sigmaX=sigma,
        )
        for n in range(batch.size(0))
    ]
    return torch.from_numpy(np.stack(smoothed)).to(target_device)
class ImageDataset(Dataset):
    """In-memory dataset of (bright-field image, force heatmap) pairs.

    Each entry of ``image_pairs`` is ``(bf_image, hm_image, numbers)`` or,
    when ``return_metadata`` is True, ``(bf_image, hm_image, numbers,
    metadata)``. ``bf_image`` and ``hm_image`` are 2-D numpy arrays;
    ``numbers`` is either a ``(cell_area, sum_force)`` tuple or a bare
    ``sum_force`` value (in which case ``cell_area`` is reported as 0).

    Args:
        image_pairs: sequence of sample tuples as described above.
        transform: optional callable ``transform(image, heatmap)`` returning
            the transformed ``(image, heatmap)`` pair.
        channel_first: when False the returned image and heatmap are
            permuted from (C, H, W) to (H, W, C).
        blur_heatmap: when True the thresholded heatmap is passed through
            ``blur_force_map``.
        threshold: heatmap values ``<= threshold`` are set to 0.
        return_metadata: when True each item also carries its metadata.
    """

    def __init__(self, image_pairs, transform=None, channel_first=True,
                 blur_heatmap=False, threshold=0.0, return_metadata=False):
        self.image_pairs = image_pairs
        self.transform = transform
        self.channel_first = channel_first
        self.blur_heatmap = blur_heatmap
        self.threshold = threshold
        self.return_metadata = return_metadata

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Return ``(image, heatmap, cell_area, sum_force[, metadata])``."""
        record = self.image_pairs[idx]
        if self.return_metadata:
            bf_image, hm_image, numbers, metadata = record
        else:
            bf_image, hm_image, numbers = record
            metadata = None

        if isinstance(numbers, tuple):
            cell_area, sum_force = numbers
        else:
            cell_area = 0
            sum_force = numbers

        # Add a leading channel dimension: (H, W) -> (1, H, W).
        image = torch.from_numpy(bf_image).float().unsqueeze(0)
        heatmap = torch.from_numpy(hm_image).float().unsqueeze(0)
        if self.transform:
            image, heatmap = self.transform(image, heatmap)

        cell_area = torch.tensor(cell_area, dtype=torch.float32)
        sum_force = torch.tensor(sum_force, dtype=torch.float32)

        # torch.from_numpy shares memory with the cached numpy array (and
        # .float() is a no-op for float32 input), so thresholding in place
        # would permanently alter the stored sample. masked_fill returns a
        # new tensor and leaves the cache untouched.
        heatmap = heatmap.masked_fill(heatmap <= self.threshold, 0.0)

        if self.blur_heatmap:
            heatmap = blur_force_map(heatmap)

        if not self.channel_first:
            # (C, H, W) -> (H, W, C). permute(2, 1, 0) would give (W, H, C),
            # i.e. it would also transpose the two spatial axes.
            image = image.permute(1, 2, 0)
            heatmap = heatmap.permute(1, 2, 0)

        if self.return_metadata:
            return image, heatmap, cell_area, sum_force, metadata
        return image, heatmap, cell_area, sum_force
def load_image(filepath, target_size):
    """Read an image as grayscale, resize it and return it as float32.

    Args:
        filepath: path of the image file to read.
        target_size: an int (used for both dimensions) or a size pair that
            is handed to ``cv2.resize`` as ``dsize``.
            NOTE(review): OpenCV interprets ``dsize`` as (width, height);
            confirm callers pass non-square sizes in that order.

    Returns:
        A float32 numpy array holding the resized pixel values divided by
        255.0.

    Raises:
        FileNotFoundError: if ``filepath`` does not exist.
        ValueError: if the file exists but OpenCV cannot decode it.
    """
    img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the error only surfaces later as an opaque
        # cv2.resize assertion that does not mention the offending path.
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Image file not found: {filepath}")
        raise ValueError(f"Could not read image: {filepath}")

    if isinstance(target_size, int):
        target_size = (target_size, target_size)
    img = cv2.resize(img, target_size)
    img = img / 255.0
    return img.astype(np.float32)
def load_text_data(filepath):
    """Parse the per-sample text file into scaled (cell_area, sum_force).

    Blank lines are ignored. The first remaining line supplies the cell
    area and the second the force sum; on each line the number is the text
    between the first and second ``:``. Values are multiplied by
    ``config.SCALE_FACTOR_AREA`` and ``config.SCALE_FACTOR_FORCE``.

    Args:
        filepath: path of the text file.

    Returns:
        Tuple ``(scaled_cell_area, scaled_sum_force)`` of floats.
    """
    with open(filepath, 'r') as handle:
        stripped = (raw.strip() for raw in handle)
        entries = [text for text in stripped if text]

    def value_after_colon(entry):
        # "<label>: <number>" -> number
        return float(entry.split(":")[1].strip())

    scaled_area = value_after_colon(entries[0]) * config.SCALE_FACTOR_AREA
    scaled_force = value_after_colon(entries[1]) * config.SCALE_FACTOR_FORCE
    return (scaled_area, scaled_force)
def load_images_from_subfolders(root_folder, target_size, load_numerical_data=True,
                                load_force_sum=False, return_metadata=False, substrate=None):
    """Load (bright-field, heatmap) image pairs from the subfolders of a root.

    Every immediate subdirectory of ``root_folder`` is searched for a file
    ending in ``BF_001.tif`` (bright field), one ending in ``_gray.jpg``
    (heatmap) and one ending in ``.txt``. Subfolders that lack a required
    file are skipped.

    Args:
        root_folder: directory whose subdirectories hold one sample each.
        target_size: forwarded to ``load_image``.
        load_numerical_data: when True a sample also requires the ``.txt``
            file, which is parsed with ``load_text_data``.
        load_force_sum: used only when ``load_numerical_data`` is False; the
            numbers become ``(0, sum(heatmap) * config.SCALE_FACTOR_FORCE)``.
        return_metadata: when True each sample tuple gets a fourth element,
            a dict with ``folder_name``, ``substrate`` and ``root_folder``.
        substrate: substrate name stored in the metadata; required when
            ``return_metadata`` is True.

    Returns:
        A list of ``(bf_image, hm_image, numbers)`` tuples, or
        ``(bf_image, hm_image, numbers, metadata)`` when ``return_metadata``
        is True. ``numbers`` is ``(0, 0)`` when neither numeric mode is on.

    Raises:
        ValueError: if ``return_metadata`` is True and ``substrate`` is None.
    """
    # Validate up front: the requirement does not depend on folder contents.
    if return_metadata and substrate is None:
        from utils.substrate_settings import list_substrates
        raise ValueError("substrate must be passed when return_metadata=True. Options: " +
                         ", ".join(list_substrates()))

    paired_images = []
    numerical_data = []
    metadata = []

    # os.listdir returns entries in arbitrary order; sorting keeps the
    # sample order (and any seeded split made from it) reproducible.
    for subfolder in sorted(os.listdir(root_folder)):
        subfolder_path = os.path.join(root_folder, subfolder)
        if not os.path.isdir(subfolder_path):
            continue

        bf_image_path = hm_image_path = txt_file_path = None
        for filename in sorted(os.listdir(subfolder_path)):
            if filename.endswith("BF_001.tif"):
                bf_image_path = os.path.join(subfolder_path, filename)
            elif filename.endswith("_gray.jpg"):
                hm_image_path = os.path.join(subfolder_path, filename)
            elif filename.endswith(".txt"):
                txt_file_path = os.path.join(subfolder_path, filename)

        if not (bf_image_path and hm_image_path):
            continue
        if load_numerical_data:
            if not txt_file_path:
                continue
            numerical_data.append(load_text_data(txt_file_path))

        paired_images.append((bf_image_path, hm_image_path))
        if return_metadata:
            # Record metadata only for samples that are actually kept, so
            # the final zip pairs every image with its own folder's entry.
            metadata.append({'folder_name': subfolder, 'substrate': substrate,
                             'root_folder': root_folder})

    with ThreadPoolExecutor() as executor:
        bf_loaded = list(executor.map(lambda p: load_image(p[0], target_size), paired_images))
        hm_loaded = list(executor.map(lambda p: load_image(p[1], target_size), paired_images))

    if load_force_sum and not load_numerical_data:
        # Reuse the heatmaps loaded above instead of decoding each twice.
        numerical_data = [(0, float(np.sum(hm)) * config.SCALE_FACTOR_FORCE)
                          for hm in hm_loaded]
    elif not numerical_data:
        numerical_data = [(0, 0)] * len(bf_loaded)

    if return_metadata:
        return list(zip(bf_loaded, hm_loaded, numerical_data, metadata))
    return list(zip(bf_loaded, hm_loaded, numerical_data))
def prepare_data(input_folder, batch_size=8, target_size=(1024, 1024), split_size=0.2,
                 use_augmentations=True, train_test_sep_folder=True, channel_first=True,
                 load_numerical_data=False, load_force_sum=False, blur_heatmap=False,
                 threshold=0.0, return_metadata=False, substrate=None):
    """Build the training and validation DataLoaders for one data folder.

    Args:
        input_folder: data root. With ``train_test_sep_folder`` it must
            contain ``train`` and ``test`` subfolders; otherwise its samples
            are split with ``train_test_split`` (``random_state=42``).
        batch_size: batch size of both loaders.
        target_size: forwarded to the image loader and the augmentations.
        split_size: validation fraction for the random split.
        use_augmentations: attach ``AdvancedAugmentations`` to the training
            dataset only.
        train_test_sep_folder: choose folder-based vs. random splitting.
        channel_first, blur_heatmap, threshold, return_metadata: forwarded
            to ``ImageDataset``.
        load_numerical_data, load_force_sum, substrate: forwarded to
            ``load_images_from_subfolders``.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.

    Raises:
        ValueError: if both numeric modes are requested, or if the
            ``train``/``test`` folders are missing.
    """
    if load_numerical_data and load_force_sum:
        raise ValueError("load_numerical_data and load_force_sum cannot be True at the same time")

    loading_options = dict(
        target_size=target_size,
        load_numerical_data=load_numerical_data,
        load_force_sum=load_force_sum,
        return_metadata=return_metadata,
        substrate=substrate,
    )
    dataset_options = dict(
        channel_first=channel_first,
        blur_heatmap=blur_heatmap,
        threshold=threshold,
        return_metadata=return_metadata,
    )

    if not train_test_sep_folder:
        all_pairs = load_images_from_subfolders(input_folder, **loading_options)
        train_pairs, val_pairs = train_test_split(all_pairs, test_size=split_size, random_state=42)
    else:
        train_folder = os.path.join(input_folder, 'train')
        test_folder = os.path.join(input_folder, 'test')
        if not os.path.exists(train_folder) or not os.path.exists(test_folder):
            raise ValueError(f"train/test folders not found in {input_folder}")
        train_pairs = load_images_from_subfolders(train_folder, **loading_options)
        val_pairs = load_images_from_subfolders(test_folder, **loading_options)

    augment = None
    if use_augmentations:
        # Imported lazily so the module is only needed when augmenting.
        from .augmentations import AdvancedAugmentations
        augment = AdvancedAugmentations(target_size)

    dataset_name = os.path.basename(input_folder)

    train_dataset = ImageDataset(train_pairs, transform=augment, **dataset_options)
    train_dataset.name = dataset_name
    val_dataset = ImageDataset(val_pairs, **dataset_options)
    val_dataset.name = dataset_name

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader
def load_folder_data(folder_path, substrate=None, img_size=1024, blur_heatmap=False,
                     batch_size=2, threshold=0.0, return_metadata=False):
    """Build a non-shuffling DataLoader over every sample in one folder.

    Samples are loaded without numerical data (neither the text file nor
    the heatmap force sum) and served channel-first.

    Args:
        folder_path: directory whose subfolders hold the samples.
        substrate: forwarded to ``load_images_from_subfolders``.
        img_size: forwarded as ``target_size``.
        blur_heatmap, threshold, return_metadata: forwarded to
            ``ImageDataset``.
        batch_size: batch size of the returned loader.

    Returns:
        A ``DataLoader`` whose dataset has a ``name`` attribute set to the
        basename of ``folder_path``.
    """
    samples = load_images_from_subfolders(
        folder_path,
        target_size=img_size,
        load_numerical_data=False,
        load_force_sum=False,
        return_metadata=return_metadata,
        substrate=substrate,
    )
    dataset = ImageDataset(
        samples,
        channel_first=True,
        blur_heatmap=blur_heatmap,
        threshold=threshold,
        return_metadata=return_metadata,
    )
    dataset.name = os.path.basename(folder_path)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return loader
def collect_image_paths(folder_path, exts=None):
    """Recursively list image files under ``folder_path`` in sorted order.

    Args:
        folder_path: directory to walk (normalised with ``os.path.normpath``).
        exts: collection of accepted extensions, each including the leading
            dot. The file's extension is lower-cased before the lookup.
            Defaults to .tif, .tiff, .jpg, .jpeg and .png.

    Returns:
        Sorted list of matching file paths.
    """
    allowed = {".tif", ".tiff", ".jpg", ".jpeg", ".png"} if exts is None else exts
    return sorted(
        os.path.join(directory, name)
        for directory, _, names in os.walk(os.path.normpath(folder_path))
        for name in names
        if os.path.splitext(name)[1].lower() in allowed
    )
class BrightfieldOnlyDataset(Dataset):
    """Dataset of brightfield images only (no labels), for inference.

    Image paths are gathered once with ``collect_image_paths``; each image
    is read with ``load_image`` when it is requested.
    """

    def __init__(self, folder_path, target_size=1024):
        self.paths = collect_image_paths(folder_path)
        # An int means a square target; anything else is stored as given.
        if isinstance(target_size, int):
            target_size = (target_size, target_size)
        self.target_size = target_size

    def __len__(self):
        """Return the number of image paths found."""
        return len(self.paths)

    def __getitem__(self, i):
        """Return image ``i`` as a float tensor with a leading channel dim."""
        pixels = load_image(self.paths[i], self.target_size)
        tensor = torch.from_numpy(pixels).float()
        return tensor.unsqueeze(0)
def load_brightfield_loader(folder_path, img_size=1024, batch_size=2):
    """Return a non-shuffling DataLoader over the images in ``folder_path``.

    Args:
        folder_path: directory handed to ``BrightfieldOnlyDataset``.
        img_size: forwarded to the dataset as ``target_size``.
        batch_size: batch size of the returned loader.
    """
    return DataLoader(
        BrightfieldOnlyDataset(folder_path, target_size=img_size),
        batch_size=batch_size,
        shuffle=False,
    )