""" Crop-based dataset loaders for training on random 512x512 crops from native 1024x1024. Both BCI and MIST variants share the same crop + augmentation logic. UNI features are extracted on-the-fly on GPU (not pre-computed). """ import os import random from pathlib import Path from typing import Optional, Tuple import torch from torch.utils.data import Dataset, DataLoader from PIL import Image import torchvision.transforms as T import torchvision.transforms.functional as TF import pytorch_lightning as pl class CropPairedDataset(Dataset): """Base class for random-crop paired H&E/IHC datasets. Loads 1024x1024 images, takes a random 512x512 crop (same position for both), and returns the crop + a UNI-ready version for on-the-fly feature extraction. """ def __init__( self, he_dir: str, ihc_dir: str, image_size: Tuple[int, int] = (512, 512), crop_size: int = 512, augment: bool = False, labels: Optional[list] = None, null_class: int = 4, ): self.he_dir = Path(he_dir) self.ihc_dir = Path(ihc_dir) self.image_size = image_size self.crop_size = crop_size self.augment = augment self.null_class = null_class self.labels = labels # UNI normalization (ImageNet stats, 224x224 per sub-crop) self.uni_crop_transform = T.Compose([ T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def _random_crop_pair(self, he_img, ihc_img): """Take the same random 512x512 crop from both images.""" w, h = he_img.size if w < self.crop_size or h < self.crop_size: raise ValueError( f"Image size {w}x{h} smaller than crop size {self.crop_size}" ) if w == self.crop_size and h == self.crop_size: return he_img, ihc_img left = random.randint(0, w - self.crop_size) top = random.randint(0, h - self.crop_size) he_crop = he_img.crop((left, top, left + self.crop_size, top + self.crop_size)) ihc_crop = ihc_img.crop((left, top, left + self.crop_size, top + self.crop_size)) return he_crop, ihc_crop def _prepare_uni_sub_crops(self, he_pil): """Split 512x512 PIL crop into 4x4 sub-crops, each resized to 224x224 with UNI normalization. Returns: [16, 3, 224, 224] tensor ready for UNI forward pass on GPU. """ w, h = he_pil.size num_crops = 4 cw = w // num_crops ch = h // num_crops sub_crops = [] for i in range(num_crops): for j in range(num_crops): left = j * cw top = i * ch sub = he_pil.crop((left, top, left + cw, top + ch)) sub_crops.append(self.uni_crop_transform(sub)) return torch.stack(sub_crops) # [16, 3, 224, 224] def _apply_paired_augmentations(self, he_img, ihc_img): """Apply identical spatial transforms to both images.""" if random.random() > 0.5: he_img = TF.hflip(he_img) ihc_img = TF.hflip(ihc_img) if random.random() > 0.5: he_img = TF.vflip(he_img) ihc_img = TF.vflip(ihc_img) if random.random() > 0.5: k = random.choice([1, 2, 3]) he_img = TF.rotate(he_img, k * 90) ihc_img = TF.rotate(ihc_img, k * 90) if random.random() > 0.7: angle = random.uniform(-15, 15) translate = [random.uniform(-0.05, 0.05) * self.image_size[1], random.uniform(-0.05, 0.05) * self.image_size[0]] scale = random.uniform(0.9, 1.1) he_img = TF.affine(he_img, angle, translate, scale, shear=0, interpolation=T.InterpolationMode.BILINEAR) ihc_img = TF.affine(ihc_img, angle, translate, scale, shear=0, interpolation=T.InterpolationMode.BILINEAR) return he_img, ihc_img def _apply_he_color_augmentation(self, he_img): """Apply color jitter to H&E only (simulates staining variability).""" if random.random() > 0.5: he_img = TF.adjust_brightness(he_img, random.uniform(0.9, 1.1)) if random.random() > 0.5: he_img = TF.adjust_contrast(he_img, random.uniform(0.9, 1.1)) if random.random() > 0.5: he_img = TF.adjust_saturation(he_img, random.uniform(0.9, 1.1)) return he_img def _process_pair(self, he_img, ihc_img, label, filename): """Common processing: crop -> augment -> tensorize -> UNI sub-crops. Returns: (he_tensor, ihc_tensor, uni_sub_crops, label, filename) - he_tensor: [3, 512, 512] in [-1, 1] - ihc_tensor: [3, 512, 512] in [-1, 1] - uni_sub_crops: [16, 3, 224, 224] with ImageNet normalization """ # Random crop (same position for both) he_crop, ihc_crop = self._random_crop_pair(he_img, ihc_img) # Augmentations (applied to PIL before UNI extraction, so features match) if self.augment: he_crop, ihc_crop = self._apply_paired_augmentations(he_crop, ihc_crop) he_aug = self._apply_he_color_augmentation(he_crop) else: he_aug = he_crop # Prepare UNI sub-crops from the augmented H&E crop uni_sub_crops = self._prepare_uni_sub_crops(he_aug) # Convert to training tensors [-1, 1] he_tensor = TF.normalize(TF.to_tensor(he_aug), [0.5]*3, [0.5]*3) ihc_tensor = TF.normalize(TF.to_tensor(ihc_crop), [0.5]*3, [0.5]*3) return he_tensor, ihc_tensor, uni_sub_crops, label, filename class BCICropDataset(CropPairedDataset): """BCI dataset with random 512 crops from 1024x1024 native images.""" HER2_LABEL_MAP = {'0': 0, '1+': 1, '2+': 2, '3+': 3} def __init__(self, he_dir, ihc_dir, image_size=(512, 512), crop_size=512, augment=False): super().__init__(he_dir, ihc_dir, image_size, crop_size, augment) self.he_images = sorted([f for f in os.listdir(he_dir) if f.endswith('.png')]) self.ihc_images = sorted([f for f in os.listdir(ihc_dir) if f.endswith('.png')]) assert len(self.he_images) == len(self.ihc_images) self.labels = [self._parse_label(f) for f in self.he_images] from collections import Counter dist = Counter(self.labels) print(f"BCI Crop Dataset: {len(self)} images, classes: {dict(sorted(dist.items()))}") def _parse_label(self, filename): parts = filename.replace('.png', '').split('_') if len(parts) >= 3: level = parts[2] if level in self.HER2_LABEL_MAP: return self.HER2_LABEL_MAP[level] raise ValueError(f"Cannot parse label from: {filename}") def __len__(self): return len(self.he_images) def __getitem__(self, idx): filename = self.he_images[idx] he_img = Image.open(self.he_dir / filename).convert('RGB') ihc_img = Image.open(self.ihc_dir / self.ihc_images[idx]).convert('RGB') return self._process_pair(he_img, ihc_img, self.labels[idx], filename) class MISTCropDataset(CropPairedDataset): """MIST dataset with random 512 crops from 1024x1024 native images.""" def __init__(self, he_dir, ihc_dir, image_size=(512, 512), crop_size=512, augment=False, null_class=4): super().__init__(he_dir, ihc_dir, image_size, crop_size, augment, null_class=null_class) valid_exts = ('.jpg', '.jpeg', '.png') self.he_images = sorted([f for f in os.listdir(he_dir) if f.lower().endswith(valid_exts)]) self.ihc_images = sorted([f for f in os.listdir(ihc_dir) if f.lower().endswith(valid_exts)]) # Verify pairing he_stems = {Path(f).stem for f in self.he_images} ihc_stems = {Path(f).stem for f in self.ihc_images} if he_stems != ihc_stems: common = he_stems & ihc_stems self.he_images = sorted([f for f in self.he_images if Path(f).stem in common]) self.ihc_images = sorted([f for f in self.ihc_images if Path(f).stem in common]) print(f"Using {len(self.he_images)} matched pairs") print(f"MIST Crop Dataset: {len(self)} images (null_class={null_class})") def __len__(self): return len(self.he_images) def __getitem__(self, idx): filename = self.he_images[idx] he_img = Image.open(self.he_dir / filename).convert('RGB') ihc_img = Image.open(self.ihc_dir / self.ihc_images[idx]).convert('RGB') return self._process_pair(he_img, ihc_img, self.null_class, filename) class BCICropDataModule(pl.LightningDataModule): def __init__(self, data_dir, batch_size=4, num_workers=4, image_size=(512, 512), crop_size=512): super().__init__() self.data_dir = Path(data_dir) self.batch_size = batch_size self.num_workers = num_workers self.image_size = image_size self.crop_size = crop_size def setup(self, stage=None): if stage == 'fit' or stage is None: self.train_dataset = BCICropDataset( he_dir=self.data_dir / 'HE' / 'train', ihc_dir=self.data_dir / 'IHC' / 'train', image_size=self.image_size, crop_size=self.crop_size, augment=True, ) if stage in ('fit', 'validate', 'test') or stage is None: self.val_dataset = BCICropDataset( he_dir=self.data_dir / 'HE' / 'test', ihc_dir=self.data_dir / 'IHC' / 'test', image_size=self.image_size, crop_size=self.crop_size, augment=False, ) def train_dataloader(self): return DataLoader( self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, pin_memory=True, persistent_workers=self.num_workers > 0, ) def val_dataloader(self): return DataLoader( self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True, persistent_workers=self.num_workers > 0, ) def test_dataloader(self): return self.val_dataloader() class MISTCropDataModule(pl.LightningDataModule): def __init__(self, data_dir, batch_size=4, num_workers=4, image_size=(512, 512), crop_size=512, null_class=4): super().__init__() self.data_dir = Path(data_dir) self.batch_size = batch_size self.num_workers = num_workers self.image_size = image_size self.crop_size = crop_size self.null_class = null_class def setup(self, stage=None): if stage == 'fit' or stage is None: self.train_dataset = MISTCropDataset( he_dir=self.data_dir / 'trainA', ihc_dir=self.data_dir / 'trainB', image_size=self.image_size, crop_size=self.crop_size, augment=True, null_class=self.null_class, ) if stage in ('fit', 'validate', 'test') or stage is None: self.val_dataset = MISTCropDataset( he_dir=self.data_dir / 'valA', ihc_dir=self.data_dir / 'valB', image_size=self.image_size, crop_size=self.crop_size, augment=False, null_class=self.null_class, ) def train_dataloader(self): return DataLoader( self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, pin_memory=True, persistent_workers=self.num_workers > 0, ) def val_dataloader(self): return DataLoader( self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True, persistent_workers=self.num_workers > 0, ) def test_dataloader(self): return self.val_dataloader()