"""dataset_loader.py — Data loading for NSGF/NSGF++ experiments.

Handles:
  - 2D synthetic datasets (8gaussians, moons, scurve, checkerboard)
  - MNIST / CIFAR-10 for image experiments
  - Source distributions (standard Gaussian)

Reference: arXiv:2401.14069, Appendix E.1 and E.2
"""

import math

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_moons, make_s_curve


# ============================================================
# 2D Synthetic Datasets (following Tong et al. 2023 / Grathwohl et al. 2018)
# ============================================================

def sample_8gaussians(n: int, scale: float = 4.0, std: float = 0.5) -> torch.Tensor:
    """Draw ``n`` points from eight Gaussian modes spaced evenly on a circle.

    Args:
        n: Number of points to draw.
        scale: Radius of the circle on which the mode centres sit.
        std: Standard deviation of the isotropic noise added around a centre.

    Returns:
        A float32 tensor of shape ``(n, 2)``.
    """
    angles = [2 * math.pi * k / 8 for k in range(8)]
    mode_centers = np.array(
        [(scale * math.cos(theta), scale * math.sin(theta)) for theta in angles]
    )
    # Mode assignment is drawn before the jitter so the global NumPy RNG is
    # consumed in a fixed order.
    mode_ids = np.random.randint(0, 8, n)
    jitter = np.random.randn(n, 2) * std
    return torch.FloatTensor(mode_centers[mode_ids] + jitter)


def sample_moons(n: int, noise: float = 0.05) -> torch.Tensor:
    """Draw ``n`` points from scikit-learn's two-moons dataset.

    The raw points are multiplied by 3 and then shifted by ``(-1, 0)``.

    Args:
        n: Number of points to draw.
        noise: Noise level forwarded to ``make_moons``.

    Returns:
        A float32 tensor of shape ``(n, 2)``.
    """
    points, _labels = make_moons(n_samples=n, noise=noise)
    shift = np.array([1.0, 0.0])
    return torch.FloatTensor(points * 3.0 - shift)


def sample_scurve(n: int, noise: float = 0.0) -> torch.Tensor:
    """Draw ``n`` points from scikit-learn's S-curve, projected to 2D.

    Columns 0 and 2 of the 3D samples are kept and multiplied by 3.

    Args:
        n: Number of points to draw.
        noise: Noise level forwarded to ``make_s_curve``.

    Returns:
        A float32 tensor of shape ``(n, 2)``.
    """
    points, _position = make_s_curve(n_samples=n, noise=noise)
    planar = points[:, [0, 2]]
    return torch.FloatTensor(planar * 3.0)


def sample_checkerboard(n: int) -> torch.Tensor:
    """Draw ``n`` points from a 4x4 checkerboard pattern.

    Args:
        n: Number of points to draw.

    Returns:
        A float32 tensor of shape ``(n, 2)``.
    """
    first = np.random.rand(n) * 4 - 2
    fraction = np.random.rand(n)
    band = np.random.randint(0, 2, n) * 2
    # Offsetting by the parity of the first coordinate's cell alternates the
    # occupied squares between neighbouring columns.
    second = (fraction - band) + (np.floor(first) % 2)
    return torch.FloatTensor(np.column_stack([first, second]) * 2)


def sample_8gaussians_moons(n: int) -> torch.Tensor:
    """Draw ``n`` points from an even mixture of 8gaussians and moons.

    ``n // 2`` points come from ``sample_8gaussians`` and the remainder from
    ``sample_moons``; the rows are then shuffled with ``torch.randperm``.

    Args:
        n: Total number of points to draw.

    Returns:
        A float32 tensor of shape ``(n, 2)``.
    """
    n_gauss = n // 2
    parts = [sample_8gaussians(n_gauss), sample_moons(n - n_gauss)]
    mixed = torch.cat(parts, dim=0)
    return mixed[torch.randperm(n)]
# Registry mapping a 2D dataset name to the sampler that draws from it.
DATASET_2D = {
    "8gaussians": sample_8gaussians,
    "moons": sample_moons,
    "scurve": sample_scurve,
    "checkerboard": sample_checkerboard,
    "8gaussians_moons": sample_8gaussians_moons,
}


def get_2d_dataset(name: str, n: int) -> torch.Tensor:
    """Draw ``n`` samples from the 2D dataset registered under ``name``.

    Args:
        name: Key into ``DATASET_2D``.
        n: Number of samples to draw.

    Returns:
        Whatever the registered sampler returns for ``n``.

    Raises:
        ValueError: If ``name`` is not a key of ``DATASET_2D``.
    """
    if name not in DATASET_2D:
        raise ValueError(f"Unknown 2D dataset: {name}. Available: {list(DATASET_2D.keys())}")
    return DATASET_2D[name](n)


def sample_source_2d(n: int, dim: int = 2) -> torch.Tensor:
    """Draw ``n`` standard-Gaussian source points of dimension ``dim``."""
    return torch.randn(n, dim)


# ============================================================
# Image Datasets (MNIST, CIFAR-10)
# ============================================================

# Per-dataset (channels, image_size). Serves as the list of supported image
# datasets and as the fallback shape when the config does not specify one.
_IMAGE_SHAPES = {
    "mnist": (1, 28),
    "cifar10": (3, 32),
}


def get_image_dataloader(
    dataset_name: str,
    batch_size: int,
    train: bool = True,
    data_root: str = "./data",
    num_workers: int = 2,
    normalize_range: tuple = (-1.0, 1.0),
) -> DataLoader:
    """Build a ``DataLoader`` over MNIST or CIFAR-10.

    Args:
        dataset_name: ``"mnist"`` or ``"cifar10"``.
        batch_size: Batch size; incomplete final batches are dropped.
        train: Selects the train split and enables shuffling when true.
        data_root: Directory passed to torchvision as the dataset root.
        num_workers: Number of loader worker processes.
        normalize_range: ``(lo, hi)`` interval that pixel values are mapped
            to. The default ``(-1.0, 1.0)`` corresponds to mean 0.5 / std 0.5.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.

    Raises:
        ValueError: If ``dataset_name`` is not supported or ``normalize_range``
            is not a strictly increasing pair.
    """
    if dataset_name not in _IMAGE_SHAPES:
        raise ValueError(f"Unknown image dataset: {dataset_name}")

    lo, hi = normalize_range
    if not hi > lo:
        raise ValueError(
            f"normalize_range must satisfy lo < hi, got {normalize_range}"
        )

    # Imported lazily so the 2D experiments do not require torchvision.
    import torchvision
    import torchvision.transforms as T

    channels, _ = _IMAGE_SHAPES[dataset_name]

    # ToTensor yields x in [0, 1] and Normalize computes (x - mean) / std, so
    # mean = -lo / (hi - lo) and std = 1 / (hi - lo) map [0, 1] onto [lo, hi].
    span = hi - lo
    mean = -lo / span
    std = 1.0 / span
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[mean] * channels, std=[std] * channels),
    ])

    if dataset_name == "mnist":
        dataset_cls = torchvision.datasets.MNIST
    else:
        dataset_cls = torchvision.datasets.CIFAR10
    ds = dataset_cls(root=data_root, train=train, download=True, transform=transform)

    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=train,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-GPU copies, so it is requested
        # only when CUDA is available.
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
    )


def sample_source_image(n: int, channels: int, image_size: int) -> torch.Tensor:
    """Draw ``n`` standard-Gaussian images of shape ``(channels, size, size)``."""
    return torch.randn(n, channels, image_size, image_size)


# ============================================================
# DatasetLoader class (unified interface)
# ============================================================

class DatasetLoader:
    """Unified sampling interface over the 2D and image datasets.

    The dataset is chosen by ``config["dataset"]`` (default ``"8gaussians"``).
    Image datasets read ``config["in_channels"]`` and ``config["image_size"]``;
    2D datasets read ``config["model"]["input_dim"]``.
    """

    def __init__(self, config: dict):
        self.config = config
        self.dataset_name = config.get("dataset", "8gaussians")
        self.is_image = self.dataset_name in _IMAGE_SHAPES
        # Lazily-built training loader state, keyed by the last batch size.
        self._image_loader = None
        self._image_iter = None
        self._image_batch_size = None

    def _image_shape(self) -> tuple:
        """Return ``(channels, image_size)``, preferring values from the config."""
        default_channels, default_size = _IMAGE_SHAPES.get(
            self.dataset_name, (1, 28)
        )
        channels = self.config.get("in_channels", default_channels)
        image_size = self.config.get("image_size", default_size)
        return channels, image_size

    def sample_target(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` target samples on ``device``.

        For image datasets this is the next training batch of size ``n``; the
        iterator restarts when an epoch is exhausted.

        Raises:
            ValueError: If ``n`` is too large for the loader to yield even one
                full batch.
        """
        if not self.is_image:
            return get_2d_dataset(self.dataset_name, n).to(device)

        # Recreate DataLoader if batch size changed (different training phases
        # use different batch sizes, e.g. 256 for pool building, 128 for NSF)
        if self._image_loader is None or self._image_batch_size != n:
            self._image_loader = get_image_dataloader(
                self.dataset_name, batch_size=n, train=True
            )
            self._image_iter = iter(self._image_loader)
            # Recorded only after the loader exists, so a failed build is
            # retried on the next call.
            self._image_batch_size = n

        try:
            images, _ = next(self._image_iter)
        except StopIteration:
            # Epoch exhausted: start a new pass over the data.
            self._image_iter = iter(self._image_loader)
            try:
                images, _ = next(self._image_iter)
            except StopIteration:
                # A fresh iterator that is already empty means drop_last
                # discarded the only (incomplete) batch.
                raise ValueError(
                    f"Batch size {n} is too large: the {self.dataset_name} "
                    "training loader yields no full batch."
                ) from None
        return images.to(device)

    def sample_source(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` standard-Gaussian source samples on ``device``."""
        if self.is_image:
            channels, image_size = self._image_shape()
            return sample_source_image(n, channels, image_size).to(device)
        dim = self.config.get("model", {}).get("input_dim", 2)
        return sample_source_2d(n, dim).to(device)

    def get_test_samples(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` evaluation samples on ``device``.

        For image datasets this is the first batch of the unshuffled test
        split; for 2D datasets it is a fresh draw from the sampler.

        Raises:
            ValueError: If ``n`` is too large for the test loader to yield even
                one full batch.
        """
        if not self.is_image:
            return get_2d_dataset(self.dataset_name, n).to(device)

        loader = get_image_dataloader(
            self.dataset_name, batch_size=n, train=False
        )
        try:
            images, _ = next(iter(loader))
        except StopIteration:
            raise ValueError(
                f"Batch size {n} is too large: the {self.dataset_name} "
                "test loader yields no full batch."
            ) from None
        return images.to(device)

    @property
    def data_dim(self) -> int:
        """Flattened dimensionality of a single sample."""
        if self.is_image:
            channels, image_size = self._image_shape()
            return channels * image_size * image_size
        return self.config.get("model", {}).get("input_dim", 2)