| """dataset_loader.py — Data loading for NSGF/NSGF++ experiments. |
| |
| Handles: |
| - 2D synthetic datasets (8gaussians, moons, scurve, checkerboard) |
| - MNIST / CIFAR-10 for image experiments |
| - Source distributions (standard Gaussian) |
| |
| Reference: arXiv:2401.14069, Appendix E.1 and E.2 |
| """ |
|
|
| import math |
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader, TensorDataset |
| from sklearn.datasets import make_moons, make_s_curve |
|
|
|
|
| |
| |
| |
|
|
def sample_8gaussians(n: int, scale: float = 4.0, std: float = 0.5) -> torch.Tensor:
    """Draw ``n`` points from a ring of eight isotropic 2D Gaussians.

    The eight mode centres are evenly spaced on a circle of radius ``scale``.
    Each sample picks one mode uniformly at random and adds Gaussian noise
    with standard deviation ``std``. Randomness comes from NumPy's global
    random state.

    Returns:
        A float32 tensor of shape ``(n, 2)``.
    """
    num_modes = 8
    angles = [2 * math.pi * k / num_modes for k in range(num_modes)]
    centers = np.array(
        [(scale * math.cos(theta), scale * math.sin(theta)) for theta in angles]
    )
    # Mode assignment is drawn before the jitter; keep this order so that a
    # fixed NumPy seed always reproduces the same samples.
    mode_ids = np.random.randint(0, num_modes, n)
    jitter = np.random.randn(n, 2) * std
    points = centers[mode_ids] + jitter
    return torch.from_numpy(points).float()
|
|
|
|
def sample_moons(n: int, noise: float = 0.05) -> torch.Tensor:
    """Draw ``n`` points from scikit-learn's two-moons dataset.

    The raw ``make_moons`` coordinates are scaled by 3 and then shifted by
    ``(-1, 0)``; the class labels are discarded.

    Returns:
        A float32 tensor of shape ``(n, 2)``.
    """
    points, _labels = make_moons(n_samples=n, noise=noise)
    offset = np.array([1.0, 0.0])
    rescaled = points * 3.0 - offset
    return torch.from_numpy(rescaled).float()
|
|
|
|
def sample_scurve(n: int, noise: float = 0.0) -> torch.Tensor:
    """Draw ``n`` points from scikit-learn's S-curve, flattened to 2D.

    Only columns 0 and 2 of the 3D ``make_s_curve`` output are kept, and the
    result is scaled by 3. The second return value of ``make_s_curve`` is
    discarded.

    Returns:
        A float32 tensor of shape ``(n, 2)``.
    """
    points_3d, _position = make_s_curve(n_samples=n, noise=noise)
    planar = points_3d[:, [0, 2]]
    return torch.from_numpy(planar * 3.0).float()
|
|
|
|
def sample_checkerboard(n: int) -> torch.Tensor:
    """Draw ``n`` points from a 4x4 checkerboard pattern.

    The first coordinate is uniform on ``[-2, 2)``. The second coordinate is
    a uniform fraction in ``[0, 1)`` placed in one of two bands (offset 0 or
    -2) and then shifted by the parity of the first coordinate's unit cell,
    which produces the alternating squares. Both coordinates are finally
    scaled by 2. Randomness comes from NumPy's global random state.

    Returns:
        A float32 tensor of shape ``(n, 2)``.
    """
    # The three draws below must stay in this order to keep seeded output
    # reproducible.
    horizontal = np.random.rand(n) * 4 - 2
    fraction = np.random.rand(n)
    band = np.random.randint(0, 2, n) * 2
    column_parity = np.floor(horizontal) % 2
    vertical = fraction - band + column_parity
    points = np.column_stack([horizontal, vertical]) * 2
    return torch.from_numpy(points).float()
|
|
|
|
def sample_8gaussians_moons(n: int) -> torch.Tensor:
    """Draw ``n`` points from an even mixture of 8gaussians and moons.

    ``n // 2`` points come from ``sample_8gaussians`` and the remaining
    ``n - n // 2`` from ``sample_moons`` (both with their default
    parameters). The concatenated rows are then shuffled with
    ``torch.randperm``.

    Returns:
        A tensor with ``n`` rows.
    """
    num_gaussian = n // 2
    num_moons = n - num_gaussian
    mixture = torch.cat(
        [sample_8gaussians(num_gaussian), sample_moons(num_moons)], dim=0
    )
    shuffled_order = torch.randperm(n)
    return mixture[shuffled_order]
|
|
|
|
# Registry of the 2D synthetic datasets: maps the dataset name used in
# configs to the sampler function that generates it. Insertion order is the
# order shown in the "Available: ..." error message of get_2d_dataset.
DATASET_2D = {
    "8gaussians": sample_8gaussians,
    "moons": sample_moons,
    "scurve": sample_scurve,
    "checkerboard": sample_checkerboard,
    "8gaussians_moons": sample_8gaussians_moons,
}
|
|
|
|
def get_2d_dataset(name: str, n: int) -> torch.Tensor:
    """Sample ``n`` points from the 2D dataset registered under ``name``.

    Args:
        name: Key into ``DATASET_2D``.
        n: Number of points, passed as the sampler's only argument (so each
            sampler runs with its default parameters).

    Returns:
        Whatever the selected sampler returns.

    Raises:
        ValueError: If ``name`` is not a key of ``DATASET_2D``.
    """
    sampler = DATASET_2D.get(name)
    if sampler is None:
        raise ValueError(f"Unknown 2D dataset: {name}. Available: {list(DATASET_2D.keys())}")
    return sampler(n)
|
|
|
|
def sample_source_2d(n: int, dim: int = 2) -> torch.Tensor:
    """Draw ``n`` samples from a standard Gaussian source in ``dim`` dimensions.

    Returns:
        A tensor of shape ``(n, dim)`` produced by ``torch.randn``.
    """
    shape = (n, dim)
    return torch.randn(shape)
|
|
|
|
| |
| |
| |
|
|
def get_image_dataloader(
    dataset_name: str,
    batch_size: int,
    train: bool = True,
    data_root: str = "./data",
    num_workers: int = 2,
    normalize_range: tuple = (-1.0, 1.0),
) -> DataLoader:
    """Build a torchvision ``DataLoader`` for MNIST or CIFAR-10.

    Images are converted with ``ToTensor`` and then affinely rescaled so that
    a ``ToTensor`` output of 0 maps to ``normalize_range[0]`` and 1 maps to
    ``normalize_range[1]``. With the default ``(-1.0, 1.0)`` this is exactly
    ``Normalize(mean=0.5, std=0.5)`` per channel.

    Args:
        dataset_name: ``"mnist"`` or ``"cifar10"``.
        batch_size: Number of images per batch.
        train: Selects the train split when true, the test split otherwise;
            also controls whether the loader shuffles.
        data_root: Directory the dataset is stored in (downloaded there if
            missing).
        num_workers: Number of ``DataLoader`` worker processes.
        normalize_range: ``(lo, hi)`` target value range, with ``lo < hi``.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches. Incomplete
        final batches are dropped (``drop_last=True``).

    Raises:
        ValueError: If ``dataset_name`` is not supported or
            ``normalize_range`` is not a strictly increasing pair.
    """
    # Validate arguments up front so a bad call fails before torchvision is
    # imported or any dataset is downloaded.
    channels_by_dataset = {"mnist": 1, "cifar10": 3}
    if dataset_name not in channels_by_dataset:
        raise ValueError(f"Unknown image dataset: {dataset_name}")
    channels = channels_by_dataset[dataset_name]

    lo, hi = normalize_range
    if not hi > lo:
        raise ValueError(
            f"normalize_range must satisfy lo < hi, got {normalize_range!r}"
        )

    import torchvision
    import torchvision.transforms as T

    # Normalize computes (x - mean) / std. Choosing mean = -lo / (hi - lo)
    # and std = 1 / (hi - lo) yields x * (hi - lo) + lo, i.e. 0 -> lo, 1 -> hi.
    span = hi - lo
    mean = -lo / span
    std = 1.0 / span
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[mean] * channels, std=[std] * channels),
    ])

    if dataset_name == "mnist":
        ds = torchvision.datasets.MNIST(
            root=data_root, train=train, download=True, transform=transform
        )
    else:
        ds = torchvision.datasets.CIFAR10(
            root=data_root, train=train, download=True, transform=transform
        )

    return DataLoader(
        ds, batch_size=batch_size, shuffle=train,
        num_workers=num_workers, pin_memory=True, drop_last=True,
    )
|
|
|
|
def sample_source_image(n: int, channels: int, image_size: int) -> torch.Tensor:
    """Draw ``n`` standard-Gaussian noise images for the source distribution.

    Returns:
        A tensor of shape ``(n, channels, image_size, image_size)`` produced
        by ``torch.randn``.
    """
    shape = (n, channels, image_size, image_size)
    return torch.randn(shape)
|
|
|
|
| |
| |
| |
|
|
class DatasetLoader:
    """Uniform sampling interface over the 2D synthetic and image datasets.

    The dataset is chosen by ``config["dataset"]`` (default ``"8gaussians"``).
    Image datasets are read through ``get_image_dataloader``; every other
    name is forwarded to ``get_2d_dataset``.

    Config keys read by this class:
        dataset: Dataset name.
        in_channels, image_size: Image shape for image datasets. When absent
            they default to the selected dataset's native shape (MNIST 1x28,
            CIFAR-10 3x32) so source noise matches the target images.
        model.input_dim: Dimensionality for 2D datasets (default 2).
    """

    # Native (channels, image_size) of each supported image dataset, used as
    # the fallback when the config does not specify the shape.
    _IMAGE_SHAPES = {"mnist": (1, 28), "cifar10": (3, 32)}

    def __init__(self, config: dict):
        self.config = config
        self.dataset_name = config.get("dataset", "8gaussians")
        self.is_image = self.dataset_name in tuple(self._IMAGE_SHAPES)
        # Training loader state, created lazily on the first image
        # sample_target call and rebuilt whenever the batch size changes.
        self._image_loader = None
        self._image_iter = None
        self._image_batch_size = None

    def _image_shape(self) -> tuple:
        """Return ``(channels, image_size)``, config values taking priority."""
        default_channels, default_size = self._IMAGE_SHAPES.get(
            self.dataset_name, (1, 28)
        )
        channels = self.config.get("in_channels", default_channels)
        image_size = self.config.get("image_size", default_size)
        return channels, image_size

    def _make_image_loader(self, n: int, train: bool) -> DataLoader:
        """Build a loader with batch size ``n`` and check it is non-empty."""
        loader = get_image_dataloader(self.dataset_name, batch_size=n, train=train)
        # The loader drops incomplete batches, so a batch larger than the
        # split produces no batches at all. Fail with a clear message instead
        # of letting a bare StopIteration escape from next().
        if len(loader) == 0:
            split = "train" if train else "test"
            raise ValueError(
                f"Cannot draw a batch of {n} samples from the "
                f"{self.dataset_name} {split} split: batch exceeds dataset size"
            )
        return loader

    def sample_target(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` target samples on ``device``.

        For image datasets, batches are drawn sequentially from a cached
        training loader; when it is exhausted a new pass is started.

        Raises:
            ValueError: If ``n`` is larger than the image training split, or
                the dataset name is unknown to ``get_2d_dataset``.
        """
        if not self.is_image:
            return get_2d_dataset(self.dataset_name, n).to(device)

        if self._image_loader is None or self._image_batch_size != n:
            self._image_loader = self._make_image_loader(n, train=True)
            self._image_batch_size = n
            self._image_iter = iter(self._image_loader)
        try:
            images, _ = next(self._image_iter)
        except StopIteration:
            # End of the pass over the data: restart the iterator. The loader
            # was verified non-empty, so this second next() cannot fail.
            self._image_iter = iter(self._image_loader)
            images, _ = next(self._image_iter)
        return images.to(device)

    def sample_source(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` standard-Gaussian source samples on ``device``."""
        if self.is_image:
            channels, image_size = self._image_shape()
            return sample_source_image(n, channels, image_size).to(device)
        dim = self.config.get("model", {}).get("input_dim", 2)
        return sample_source_2d(n, dim).to(device)

    def get_test_samples(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` evaluation samples on ``device``.

        For image datasets this is the first batch of a freshly built
        test-split loader; for 2D datasets it is a fresh draw from the
        sampler.

        Raises:
            ValueError: If ``n`` is larger than the image test split, or the
                dataset name is unknown to ``get_2d_dataset``.
        """
        if not self.is_image:
            return get_2d_dataset(self.dataset_name, n).to(device)
        loader = self._make_image_loader(n, train=False)
        images, _ = next(iter(loader))
        return images.to(device)

    @property
    def data_dim(self) -> int:
        """Flattened dimensionality of one sample."""
        if self.is_image:
            channels, image_size = self._image_shape()
            return channels * image_size * image_size
        return self.config.get("model", {}).get("input_dim", 2)
|
|