# nsgf-plusplus / dataset_loader.py
# Last commit 9e3fccc by rogermt (verified): "Fix DataLoader batch size mismatch
# across training phases + --train-iters now overrides all phases"
"""dataset_loader.py — Data loading for NSGF/NSGF++ experiments.
Handles:
- 2D synthetic datasets (8gaussians, moons, scurve, checkerboard, and the
  8gaussians_moons mixture)
- MNIST / CIFAR-10 for image experiments
- Source distributions (standard Gaussian)
Reference: arXiv:2401.14069, Appendix E.1 and E.2
"""
import math
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_moons, make_s_curve
# ============================================================
# 2D Synthetic Datasets (following Tong et al. 2023 / Grathwohl et al. 2018)
# ============================================================
def sample_8gaussians(n: int, scale: float = 4.0, std: float = 0.5) -> torch.Tensor:
    """Sample ``n`` points from 8 isotropic Gaussians placed evenly on a circle.

    Mode centres sit at radius ``scale`` and angles ``2*pi*k/8`` (k = 0..7).
    Each point picks a mode uniformly at random and adds N(0, std^2) noise
    per coordinate. Returns an ``(n, 2)`` float32 tensor.
    """
    mode_centres = np.array([
        (scale * math.cos(theta), scale * math.sin(theta))
        for theta in (2 * math.pi * k / 8 for k in range(8))
    ])
    # Draw order (mode labels first, then noise) fixes the global RNG stream.
    labels = np.random.randint(0, 8, n)
    jitter = np.random.randn(n, 2) * std
    return torch.FloatTensor(mode_centres[labels] + jitter)
def sample_moons(n: int, noise: float = 0.05) -> torch.Tensor:
    """Sample ``n`` points from scikit-learn's two interleaving half-circles.

    The raw ``make_moons`` points are scaled by 3 and then shifted by -1 along
    the first axis. Returns an ``(n, 2)`` float32 tensor.
    """
    raw_points = make_moons(n_samples=n, noise=noise)[0]
    offset = np.array([1.0, 0.0])
    return torch.FloatTensor(raw_points * 3.0 - offset)
def sample_scurve(n: int, noise: float = 0.0) -> torch.Tensor:
    """Sample ``n`` points from scikit-learn's 3D S-curve, flattened to 2D.

    Keeps the first and third coordinates (dropping the second) and scales
    them by 3. Returns an ``(n, 2)`` float32 tensor.
    """
    curve_3d = make_s_curve(n_samples=n, noise=noise)[0]
    kept_axes = curve_3d[:, [0, 2]]
    return torch.FloatTensor(kept_axes * 3.0)
def sample_checkerboard(n: int) -> torch.Tensor:
    """Sample ``n`` points uniformly from the filled cells of a 4x4 checkerboard.

    Before the final x2 scaling both coordinates lie in [-2, 2). Only unit cells
    whose integer corner coordinates have an even sum are populated. Returns an
    ``(n, 2)`` float32 tensor.
    """
    horiz = np.random.rand(n) * 4 - 2
    # Keep the RNG draw order: the second rand() happens before randint().
    vert = np.random.rand(n)
    vert = vert - np.random.randint(0, 2, n) * 2
    # Odd columns are lifted by one cell, which produces the alternating pattern.
    vert = vert + (np.floor(horiz) % 2)
    return torch.FloatTensor(np.stack((horiz, vert), axis=1) * 2)
def sample_8gaussians_moons(n: int) -> torch.Tensor:
    """Sample an even mixture of the 8-Gaussians and moons datasets.

    ``n // 2`` points come from ``sample_8gaussians`` and the remainder from
    ``sample_moons``. The rows are then randomly permuted so the two sources
    are interleaved.
    """
    first_half = n // 2
    # Tuple elements evaluate left to right, so gaussians are drawn before moons.
    combined = torch.cat((sample_8gaussians(first_half), sample_moons(n - first_half)))
    return combined[torch.randperm(n)]
# Registry of 2D target samplers: name -> callable taking a sample count n and
# returning an (n, 2) float32 tensor. Insertion order is user-visible because
# get_2d_dataset lists these keys in its "Unknown 2D dataset" error message.
DATASET_2D = {
    "8gaussians": sample_8gaussians,
    "moons": sample_moons,
    "scurve": sample_scurve,
    "checkerboard": sample_checkerboard,
    "8gaussians_moons": sample_8gaussians_moons,
}
def get_2d_dataset(name: str, n: int) -> torch.Tensor:
    """Draw ``n`` samples from the 2D dataset registered under ``name``.

    Raises:
        ValueError: ``name`` is not a key of ``DATASET_2D``; the message lists
            the available names.
    """
    sampler = DATASET_2D.get(name)
    if sampler is None:
        available = list(DATASET_2D.keys())
        raise ValueError(f"Unknown 2D dataset: {name}. Available: {available}")
    return sampler(n)
def sample_source_2d(n: int, dim: int = 2) -> torch.Tensor:
    """Draw ``n`` i.i.d. standard-normal source points of dimension ``dim``."""
    shape = (n, dim)
    return torch.randn(*shape)
# ============================================================
# Image Datasets (MNIST, CIFAR-10)
# ============================================================
def get_image_dataloader(
    dataset_name: str,
    batch_size: int,
    train: bool = True,
    data_root: str = "./data",
    num_workers: int = 2,
    normalize_range: tuple = (-1.0, 1.0),
) -> DataLoader:
    """Build a DataLoader over the MNIST or CIFAR-10 train/test split.

    Pixels are converted with ``ToTensor`` (values in [0, 1]) and then mapped
    affinely onto ``normalize_range``.

    Args:
        dataset_name: ``"mnist"`` (1 channel) or ``"cifar10"`` (3 channels).
        batch_size: Images per batch. Incomplete final batches are dropped.
        train: Use the training split and shuffle if True; otherwise use the
            test split without shuffling.
        data_root: Directory torchvision stores/downloads the dataset into.
        num_workers: Number of DataLoader worker processes.
        normalize_range: Target pixel range ``(lo, hi)`` with ``lo < hi``.
            The default ``(-1.0, 1.0)`` equals ``Normalize(mean=0.5, std=0.5)``.

    Returns:
        A DataLoader yielding ``(images, labels)`` batches.

    Raises:
        ValueError: ``dataset_name`` is unknown or ``normalize_range`` does
            not satisfy ``lo < hi``.
    """
    channels_by_dataset = {"mnist": 1, "cifar10": 3}
    if dataset_name not in channels_by_dataset:
        raise ValueError(f"Unknown image dataset: {dataset_name}")
    lo, hi = normalize_range
    if not hi > lo:  # also rejects NaN bounds
        raise ValueError(
            f"normalize_range must satisfy lo < hi, got {normalize_range}"
        )

    # torchvision is heavy and only needed here, so import it after the
    # arguments have been validated.
    import torchvision
    import torchvision.transforms as T

    # Normalize computes (x - mean) / std. For x in [0, 1], choosing
    # mean = -lo / (hi - lo) and std = 1 / (hi - lo) maps onto [lo, hi].
    span = hi - lo
    channels = channels_by_dataset[dataset_name]
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[-lo / span] * channels, std=[1.0 / span] * channels),
    ])

    dataset_cls = (
        torchvision.datasets.MNIST
        if dataset_name == "mnist"
        else torchvision.datasets.CIFAR10
    )
    ds = dataset_cls(root=data_root, train=train, download=True, transform=transform)
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=train,
        num_workers=num_workers,
        # Pinned memory only speeds host->GPU copies; without CUDA it just warns.
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
    )
def sample_source_image(n: int, channels: int, image_size: int) -> torch.Tensor:
    """Draw ``n`` standard-normal noise images of shape ``(channels, image_size, image_size)``."""
    spatial = (image_size, image_size)
    return torch.randn((n, channels) + spatial)
# ============================================================
# DatasetLoader class (unified interface)
# ============================================================
class DatasetLoader:
    """Unified target/source sampler for 2D synthetic and image experiments.

    Config keys read by this class:
        ``"dataset"``     -- dataset name (default ``"8gaussians"``). ``"mnist"``
                             and ``"cifar10"`` select the image path; any other
                             name is resolved through ``get_2d_dataset``.
        ``"in_channels"`` -- image channel count (default 1).
        ``"image_size"``  -- image height/width (default 28).
        ``"model"``       -- sub-dict whose ``"input_dim"`` (default 2) gives
                             the data dimension on the 2D path.
    """

    def __init__(self, config: dict):
        self.config = config
        self.dataset_name = config.get("dataset", "8gaussians")
        self.is_image = self.dataset_name in ("mnist", "cifar10")
        # Training-split loader state for sample_target(), built lazily. The
        # batch size is recorded so the loader can be rebuilt when it changes.
        self._image_loader = None
        self._image_batch_size = None
        self._image_iter = None

    def _image_shape(self) -> tuple:
        """Return ``(channels, image_size)`` from the config, with defaults."""
        return self.config.get("in_channels", 1), self.config.get("image_size", 28)

    def sample_target(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` samples from the target (data) distribution on ``device``.

        2D datasets are sampled fresh on every call. Image datasets return the
        next training batch and restart iteration over the loader when the
        current pass is exhausted.

        Raises:
            ValueError: the image loader yields no batch at all (the loader
                drops incomplete batches, so this happens when ``n`` exceeds
                the split size).
        """
        if not self.is_image:
            return get_2d_dataset(self.dataset_name, n).to(device)

        # A DataLoader's batch size is fixed at construction, and different
        # training phases request different sizes (e.g. 256 for pool building,
        # 128 for NSF), so rebuild the loader whenever n changes.
        if self._image_loader is None or self._image_batch_size != n:
            loader = get_image_dataloader(self.dataset_name, batch_size=n, train=True)
            # Commit state only after construction succeeds, so a failed
            # rebuild cannot leave the old loader paired with the new size.
            self._image_loader = loader
            self._image_batch_size = n
            self._image_iter = iter(loader)

        batch = next(self._image_iter, None)
        if batch is None:
            # Current pass exhausted: start a new pass over the same loader.
            self._image_iter = iter(self._image_loader)
            batch = next(self._image_iter, None)
            if batch is None:
                raise ValueError(
                    f"Image loader for {self.dataset_name!r} yielded no "
                    f"batches of size {n}"
                )
        images, _ = batch
        return images.to(device)

    def sample_source(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` standard-normal source samples shaped like the data."""
        if self.is_image:
            channels, image_size = self._image_shape()
            return sample_source_image(n, channels, image_size).to(device)
        dim = self.config.get("model", {}).get("input_dim", 2)
        return sample_source_2d(n, dim).to(device)

    def get_test_samples(self, n: int, device: str = "cpu") -> torch.Tensor:
        """Return ``n`` held-out samples on ``device``.

        Image datasets return the first (unshuffled) batch of the test split.
        2D datasets return a fresh draw, because they have no fixed test set.

        Raises:
            ValueError: the test loader yields no full batch of size ``n``.
        """
        if not self.is_image:
            return get_2d_dataset(self.dataset_name, n).to(device)
        loader = get_image_dataloader(self.dataset_name, batch_size=n, train=False)
        batch = next(iter(loader), None)
        if batch is None:
            raise ValueError(
                f"Cannot draw {n} test images from {self.dataset_name!r}: "
                f"the test loader yielded no full batch"
            )
        images, _ = batch
        return images.to(device)

    @property
    def data_dim(self) -> int:
        """Flattened dimensionality of one data sample."""
        if self.is_image:
            channels, image_size = self._image_shape()
            return channels * image_size * image_size
        return self.config.get("model", {}).get("input_dim", 2)