UNIStainNet / src /data /mist_dataset.py
faceless-void's picture
Upload folder using huggingface_hub
4db9215 verified
"""
Multi-stain dataset for training a single model on all MIST IHC stains.
Combines HER2, Ki67, ER, PR into one dataset, returning a stain label (0-3)
instead of a class label. Reuses the same crop + UNI sub-crop pipeline
from CropPairedDataset.
Stain label mapping:
0 = HER2, 1 = Ki67, 2 = ER, 3 = PR, 4 = null (CFG dropout)
"""
import os
from pathlib import Path
from typing import List, Optional, Tuple
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pytorch_lightning as pl
from src.data.bci_dataset import CropPairedDataset
# Stain name -> integer conditioning label. The position of a name in the
# tuple below is its label (HER2=0, Ki67=1, ER=2, PR=3).
STAIN_TO_LABEL = dict(zip(('HER2', 'Ki67', 'ER', 'PR'), range(4)))
# Reverse lookup: integer label -> stain name.
LABEL_TO_STAIN = dict(zip(STAIN_TO_LABEL.values(), STAIN_TO_LABEL))
class MISTMultiStainCropDataset(CropPairedDataset):
    """Paired H&E / IHC dataset spanning several MIST stains.

    All requested stains are indexed into one flat sample list. Each sample
    carries the stain's integer label from ``STAIN_TO_LABEL``; that label is
    handed to the inherited ``_process_pair`` in the slot where a class label
    would otherwise go.

    Expected on-disk layout, per stain::

        <base_dir>/<stain>/TrainValAB/trainA   # H&E, training split
        <base_dir>/<stain>/TrainValAB/trainB   # IHC, training split
        <base_dir>/<stain>/TrainValAB/valA     # H&E, validation split
        <base_dir>/<stain>/TrainValAB/valB     # IHC, validation split

    H&E and IHC files are paired by file stem, so the two sides may use
    different image extensions.
    """

    def __init__(
        self,
        base_dir: str,
        stains: List[str],
        split: str = 'train',
        image_size: Tuple[int, int] = (512, 512),
        crop_size: int = 512,
        augment: bool = False,
        null_class: int = 4,
    ):
        """Index every H&E/IHC pair for the requested stains.

        Args:
            base_dir: Root directory containing one sub-directory per stain.
            stains: Stain names to load; each must be a key of
                ``STAIN_TO_LABEL``.
            split: ``'train'`` selects ``trainA``/``trainB``. Any other value
                selects ``valA``/``valB``.
            image_size: Forwarded unchanged to ``CropPairedDataset.__init__``.
            crop_size: Forwarded unchanged to ``CropPairedDataset.__init__``.
            augment: Forwarded unchanged to ``CropPairedDataset.__init__``.
            null_class: Forwarded unchanged to ``CropPairedDataset.__init__``.

        Raises:
            ValueError: If a stain name is not in ``STAIN_TO_LABEL``.
            FileNotFoundError: If the H&E or IHC directory of a stain is
                missing.
        """
        from collections import Counter

        # The parent's directory arguments are placeholders: this class keeps
        # its own sample list and overrides __len__ / __getitem__.
        super().__init__(
            he_dir='.',
            ihc_dir='.',
            image_size=image_size,
            crop_size=crop_size,
            augment=augment,
            null_class=null_class,
        )
        self.base_dir = Path(base_dir)
        self.samples = []  # (he_path, ihc_path, stain_label)

        is_train = split == 'train'
        split_he = 'trainA' if is_train else 'valA'
        split_ihc = 'trainB' if is_train else 'valB'
        valid_exts = ('.jpg', '.jpeg', '.png')

        for stain in stains:
            if stain not in STAIN_TO_LABEL:
                raise ValueError(f"Unknown stain: {stain}. Must be one of {list(STAIN_TO_LABEL.keys())}")
            stain_label = STAIN_TO_LABEL[stain]

            he_dir = self.base_dir / stain / 'TrainValAB' / split_he
            ihc_dir = self.base_dir / stain / 'TrainValAB' / split_ihc
            if not he_dir.exists():
                raise FileNotFoundError(f"H&E directory not found: {he_dir}")
            if not ihc_dir.exists():
                raise FileNotFoundError(f"IHC directory not found: {ihc_dir}")

            # Map stem -> filename on each side. Sorting first keeps the
            # result deterministic when two files share a stem (the
            # lexicographically last filename wins).
            he_stems = {
                Path(f).stem: f
                for f in sorted(os.listdir(he_dir))
                if f.lower().endswith(valid_exts)
            }
            ihc_stems = {
                Path(f).stem: f
                for f in sorted(os.listdir(ihc_dir))
                if f.lower().endswith(valid_exts)
            }

            # Only stems present on both sides form a usable pair.
            common = sorted(he_stems.keys() & ihc_stems.keys())
            self.samples.extend(
                (he_dir / he_stems[stem], ihc_dir / ihc_stems[stem], stain_label)
                for stem in common
            )
            print(f"  {stain} ({split}): {len(common)} pairs")

        # Per-stain counts for logging
        dist = Counter(label for _, _, label in self.samples)
        stain_counts = {LABEL_TO_STAIN[k]: v for k, v in sorted(dist.items())}
        print(f"Multi-Stain Crop Dataset ({split}): {len(self.samples)} total | {stain_counts}")

    def __len__(self):
        """Return the total number of indexed pairs across all stains."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return the result of ``_process_pair``.

        Both images are decoded to RGB and passed to the inherited
        ``_process_pair`` together with the stain label and the H&E filename.
        """
        he_path, ihc_path, stain_label = self.samples[idx]
        # Open inside a context manager so the underlying file handles are
        # released deterministically; convert() returns an independent,
        # fully loaded image that stays valid after the files are closed.
        with Image.open(he_path) as he_file, Image.open(ihc_path) as ihc_file:
            he_img = he_file.convert('RGB')
            ihc_img = ihc_file.convert('RGB')
        return self._process_pair(he_img, ihc_img, stain_label, he_path.name)
class MISTMultiStainCropDataModule(pl.LightningDataModule):
    """LightningDataModule that serves ``MISTMultiStainCropDataset`` splits.

    The training dataset is built with ``augment=True`` and the validation
    dataset with ``augment=False``. The test loader is the validation loader.
    """

    def __init__(
        self,
        base_dir: str,
        stains: Optional[List[str]] = None,
        batch_size: int = 4,
        num_workers: int = 4,
        image_size: Tuple[int, int] = (512, 512),
        crop_size: int = 512,
        null_class: int = 4,
    ):
        """Store the configuration; datasets are created later in ``setup``.

        Args:
            base_dir: Dataset root, passed on to the dataset.
            stains: Stain names to load. ``None`` (or an empty list) falls
                back to all four: HER2, Ki67, ER, PR.
            batch_size: Batch size used by every loader.
            num_workers: Worker process count used by every loader.
            image_size: Passed on to the dataset.
            crop_size: Passed on to the dataset.
            null_class: Passed on to the dataset.
        """
        super().__init__()
        self.base_dir = base_dir
        self.stains = stains if stains else ['HER2', 'Ki67', 'ER', 'PR']
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.image_size = image_size
        self.crop_size = crop_size
        self.null_class = null_class

    def setup(self, stage=None):
        """Instantiate the datasets required for ``stage``.

        ``train_dataset`` is created for ``'fit'`` or ``None``;
        ``val_dataset`` is created for ``'fit'``, ``'validate'``, ``'test'``
        or ``None``.
        """
        # Arguments common to both splits.
        shared = dict(
            base_dir=self.base_dir,
            stains=self.stains,
            image_size=self.image_size,
            crop_size=self.crop_size,
            null_class=self.null_class,
        )
        if stage in (None, 'fit'):
            self.train_dataset = MISTMultiStainCropDataset(
                split='train', augment=True, **shared,
            )
        if stage in (None, 'fit', 'validate', 'test'):
            self.val_dataset = MISTMultiStainCropDataset(
                split='val', augment=False, **shared,
            )

    def train_dataloader(self):
        """Return a shuffling loader over ``train_dataset``."""
        # persistent_workers is only requested when worker processes exist.
        keep_workers = self.num_workers > 0
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=keep_workers,
        )

    def val_dataloader(self):
        """Return a non-shuffling loader over ``val_dataset``."""
        keep_workers = self.num_workers > 0
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=keep_workers,
        )

    def test_dataloader(self):
        """Return the validation loader; no separate test split is used."""
        return self.val_dataloader()