Spaces:
Running
Running
| """ | |
| Multi-stain dataset for training a single model on all MIST IHC stains. | |
| Combines HER2, Ki67, ER, PR into one dataset, returning a stain label (0-3) | |
| instead of a class label. Reuses the same crop + UNI sub-crop pipeline | |
| from CropPairedDataset. | |
| Stain label mapping: | |
| 0 = HER2, 1 = Ki67, 2 = ER, 3 = PR, 4 = null (CFG dropout) | |
| """ | |
| import os | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from PIL import Image | |
| import pytorch_lightning as pl | |
| from src.data.bci_dataset import CropPairedDataset | |
# Stain name -> integer conditioning label. Insertion order follows the label
# values (HER2=0, Ki67=1, ER=2, PR=3).
STAIN_TO_LABEL = dict(zip(('HER2', 'Ki67', 'ER', 'PR'), range(4)))

# Reverse lookup (integer label -> stain name), used when printing counts.
LABEL_TO_STAIN = {label: stain for stain, label in STAIN_TO_LABEL.items()}
class MISTMultiStainCropDataset(CropPairedDataset):
    """Paired H&E / IHC dataset spanning several MIST stains.

    Each sample carries an integer stain label (see ``STAIN_TO_LABEL``) that
    is handed to the inherited ``_process_pair`` in the position of a class
    label. Cropping, resizing and augmentation are not implemented in this
    class; they are whatever ``CropPairedDataset._process_pair`` does with
    the ``image_size`` / ``crop_size`` / ``augment`` / ``null_class``
    settings forwarded to the parent constructor.

    Directory layout read by the constructor::

        <base_dir>/<stain>/TrainValAB/trainA   H&E, when split == 'train'
        <base_dir>/<stain>/TrainValAB/trainB   IHC, when split == 'train'
        <base_dir>/<stain>/TrainValAB/valA     H&E, for any other split
        <base_dir>/<stain>/TrainValAB/valB     IHC, for any other split
    """

    def __init__(
        self,
        base_dir: str,
        stains: List[str],
        split: str = 'train',
        image_size: Tuple[int, int] = (512, 512),
        crop_size: int = 512,
        augment: bool = False,
        null_class: int = 4,
    ):
        """Discover all H&E / IHC pairs for the requested stains.

        Args:
            base_dir: Root folder containing one sub-folder per stain.
            stains: Stain names to load; each must be a key of
                ``STAIN_TO_LABEL``.
            split: ``'train'`` selects the ``trainA`` / ``trainB`` folders;
                every other value selects ``valA`` / ``valB``.
            image_size: Forwarded unchanged to ``CropPairedDataset``.
            crop_size: Forwarded unchanged to ``CropPairedDataset``.
            augment: Forwarded unchanged to ``CropPairedDataset``.
            null_class: Forwarded unchanged to ``CropPairedDataset``.

        Raises:
            ValueError: If a stain name is not in ``STAIN_TO_LABEL``.
            FileNotFoundError: If the H&E or IHC folder of a stain is missing.
        """
        # The '.' directories are placeholders: this class builds its own
        # sample list below and overrides __getitem__.
        super().__init__(
            he_dir='.',
            ihc_dir='.',
            image_size=image_size,
            crop_size=crop_size,
            augment=augment,
            null_class=null_class,
        )
        self.base_dir = Path(base_dir)
        self.samples = []  # entries: (he_path, ihc_path, stain_label)

        valid_exts = ('.jpg', '.jpeg', '.png')
        prefix = 'train' if split == 'train' else 'val'
        split_he, split_ihc = f'{prefix}A', f'{prefix}B'

        def index_by_stem(folder):
            """Map filename stem -> filename for the image files in *folder*."""
            names = [n for n in os.listdir(folder) if n.lower().endswith(valid_exts)]
            # Sorted so that, if two files share a stem, the same one always wins.
            names.sort()
            return {Path(n).stem: n for n in names}

        for stain in stains:
            stain_label = STAIN_TO_LABEL.get(stain)
            if stain_label is None:
                raise ValueError(f"Unknown stain: {stain}. Must be one of {list(STAIN_TO_LABEL.keys())}")

            stain_root = self.base_dir / stain / 'TrainValAB'
            he_dir = stain_root / split_he
            ihc_dir = stain_root / split_ihc

            # Both folders are validated before either one is listed.
            if not he_dir.exists():
                raise FileNotFoundError(f"H&E directory not found: {he_dir}")
            if not ihc_dir.exists():
                raise FileNotFoundError(f"IHC directory not found: {ihc_dir}")

            # Pair by stem because the two sides may use different extensions
            # (H&E may be .jpg, IHC may be .png).
            he_stems = index_by_stem(he_dir)
            ihc_stems = index_by_stem(ihc_dir)
            common = sorted(he_stems.keys() & ihc_stems.keys())

            self.samples.extend(
                (he_dir / he_stems[stem], ihc_dir / ihc_stems[stem], stain_label)
                for stem in common
            )
            print(f" {stain} ({split}): {len(common)} pairs")

        # Per-stain totals, reported in ascending label order.
        from collections import Counter
        tally = Counter(label for _, _, label in self.samples)
        stain_counts = {LABEL_TO_STAIN[label]: tally[label] for label in sorted(tally)}
        print(f"Multi-Stain Crop Dataset ({split}): {len(self.samples)} total | {stain_counts}")

    def __len__(self):
        """Return the number of discovered H&E / IHC pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load pair *idx* as RGB images and pass it to ``_process_pair``.

        The stain label takes the label argument and the H&E file name is
        passed as the last argument; the return value is whatever the
        inherited ``_process_pair`` produces.
        """
        he_path, ihc_path, stain_label = self.samples[idx]
        he_img, ihc_img = (
            Image.open(path).convert('RGB') for path in (he_path, ihc_path)
        )
        return self._process_pair(he_img, ihc_img, stain_label, he_path.name)
class MISTMultiStainCropDataModule(pl.LightningDataModule):
    """LightningDataModule that serves ``MISTMultiStainCropDataset`` splits.

    The training split is built with augmentation enabled and shuffled; the
    validation split is built without augmentation and is also returned by
    ``test_dataloader``.
    """

    def __init__(
        self,
        base_dir: str,
        stains: Optional[List[str]] = None,
        batch_size: int = 4,
        num_workers: int = 4,
        image_size: Tuple[int, int] = (512, 512),
        crop_size: int = 512,
        null_class: int = 4,
    ):
        """Store the configuration; no data is touched until ``setup``.

        Args:
            base_dir: Root folder passed to the dataset.
            stains: Stain names to load. ``None`` or an empty list selects
                HER2, Ki67, ER and PR.
            batch_size: Batch size used by every dataloader.
            num_workers: Worker process count used by every dataloader.
            image_size: Forwarded unchanged to the dataset.
            crop_size: Forwarded unchanged to the dataset.
            null_class: Forwarded unchanged to the dataset.
        """
        super().__init__()
        if not stains:
            stains = ['HER2', 'Ki67', 'ER', 'PR']
        self.base_dir = base_dir
        self.stains = stains
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.image_size = image_size
        self.crop_size = crop_size
        self.null_class = null_class

    def _build_mist_dataset(self, split, augment):
        """Create the dataset for one split from the stored configuration."""
        return MISTMultiStainCropDataset(
            base_dir=self.base_dir,
            stains=self.stains,
            split=split,
            image_size=self.image_size,
            crop_size=self.crop_size,
            augment=augment,
            null_class=self.null_class,
        )

    def _build_loader(self, dataset, shuffle):
        """Wrap *dataset* in a DataLoader using the shared loader settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            # persistent_workers is only requested when workers exist.
            persistent_workers=self.num_workers > 0,
        )

    def setup(self, stage=None):
        """Build the datasets required for *stage*.

        ``'fit'`` or ``None`` builds the training set; ``'fit'``,
        ``'validate'``, ``'test'`` or ``None`` builds the validation set.
        Any other stage builds nothing.
        """
        if stage in (None, 'fit'):
            self.train_dataset = self._build_mist_dataset('train', augment=True)
        if stage in (None, 'fit', 'validate', 'test'):
            self.val_dataset = self._build_mist_dataset('val', augment=False)

    def train_dataloader(self):
        """Return a shuffled loader over the training dataset."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation dataset."""
        return self._build_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the validation loader; there is no separate test split."""
        return self.val_dataloader()