| |
| """ |
| End-to-end parameter inference for the conditional DDPM stack: |
| training (noise-prediction MSE / ELBO surrogate), checkpointing, conditional |
| sampling, evaluation-style metrics, and optional **VLB-based cosmological |
| parameter constraints** following Mudur et al. (2023). |
| |
| Reference (parameter inference via conditional diffusion VLB): |
| Mudur, Cuesta-Lazaro & Finkbeiner, "Cosmological Field Emulation and |
| Parameter Inference with Diffusion Models", arXiv:2312.07534 (2023). |
| https://arxiv.org/abs/2312.07534 |
| |
| They train a DDPM (Ho et al. 2020) on log density fields conditioned on |
| (Omega_m, sigma_8), then evaluate VLB terms L_t(x_0 | theta_eval) on a |
| grid in parameter space. The dominant term is L_0 = -log p_phi(x_0 | x_1, theta) |
| with x_1 ~ q(x_1|x_0). They form -2 Delta ln L_hat ~ 2(L_0 - min L_0) and |
| map marginals to approximate posteriors (68% intervals on a grid). |
| |
| This script implements the **L_0 approximation** (their primary reported setup) |
| using the existing GaussianDiffusion reverse mean/variance at timestep index t=1. |
| Full multi-t VLB sums are left as a documented extension. |
| |
| Note: train_conditional.py exposes hyperparameters via argparse (no separate |
| Config dataclass). This script mirrors those fields and uses the same training |
| utilities (EMA, AMP, grad clip inside train_epoch). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| import math |
| import os |
| import random |
| import sys |
| import time |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import matplotlib |
|
|
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import torch.optim as optim |
| from torch.utils.data import DataLoader |
|
|
| from dataset_conditional import get_conditional_dataloaders |
| from diffusion_conditional import ConditionalDiffusionModel, GaussianDiffusion |
| from evaluate_conditional import ( |
| build_model, |
| calculate_pdf_batch, |
| calculate_power_spectrum_batch, |
| from_model_output, |
| load_checkpoint, |
| load_label_stats, |
| load_split, |
| load_training_config, |
| prepare_labels_for_model, |
| ) |
| from train_conditional import ( |
| EMA, |
| save_checkpoint, |
| save_training_args, |
| train_epoch, |
| validate, |
| ) |
| from unet_conditional import ConditionalUNet |
|
|
|
|
def _setup_logging(log_path: Optional[Path] = None) -> logging.Logger:
    """(Re)configure the script's named logger and return it.

    The logger ``parameter_inference_conditional`` is reset to INFO level with a
    stdout handler and, when ``log_path`` is given, an additional UTF-8 file
    handler. Calling this again (as ``run_training`` does after ``main``)
    replaces the previous handlers.

    Args:
        log_path: Optional file to mirror log records into.

    Returns:
        The configured ``logging.Logger`` (same object on every call).
    """
    log = logging.getLogger("parameter_inference_conditional")
    # Close before discarding: a bare handlers.clear() would leave any previous
    # FileHandler's file descriptor open for the rest of the process.
    for old_handler in list(log.handlers):
        log.removeHandler(old_handler)
        old_handler.close()
    log.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")
    sh = logging.StreamHandler(sys.stdout)
    sh.setFormatter(fmt)
    log.addHandler(sh)
    if log_path is not None:
        fh = logging.FileHandler(log_path, encoding="utf-8")
        fh.setFormatter(fmt)
        log.addHandler(fh)
    return log
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (CPU and, if present, all CUDA devices).

    Also puts cuDNN into deterministic mode and disables its autotuner so that
    repeated runs with the same seed pick the same kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
|
|
def _infer_spatial_size(loader: DataLoader) -> Tuple[int, int]:
    """Return ``(height, width)`` of the first image in ``loader.dataset``.

    Dataset items are unpacked as ``(image, label)`` pairs. The last two image
    axes are taken as H and W, so ``[C, H, W]`` and ``[H, W]`` images both work.
    """
    first_image, _label = loader.dataset[0]
    height, width = first_image.shape[-2], first_image.shape[-1]
    return int(height), int(width)
|
|
|
|
def save_conditioned_sample_grid(
    model: ConditionalDiffusionModel,
    diffusion: GaussianDiffusion,
    labels: torch.Tensor,
    device: torch.device,
    save_path: Path,
    *,
    channels: int,
    height: int,
    width: int,
    ema: Optional[EMA],
    use_ddim: bool,
    ddim_steps: int,
    title: str = "Conditional samples",
) -> None:
    """Sample one field per label vector and save the results as an image grid.

    Same idea as ``train_conditional.sample_images``, but the spatial size comes
    from the data. When ``ema`` is given, its shadow weights are swapped in for
    sampling, and the live weights are swapped back afterwards even if sampling
    raises. ``model.unet`` is put in eval mode for sampling.

    Args:
        model: Conditional diffusion model passed to ``diffusion.sample``.
        diffusion: Object providing ``sample(...)``.
        labels: Conditioning vectors, one row per generated sample.
        device: Device to sample on.
        save_path: Output image path; parent directories are created.
        channels, height, width: Shape of each generated field.
        ema: Optional EMA helper whose shadow weights are used for sampling.
        use_ddim, ddim_steps: Forwarded to ``diffusion.sample`` (``eta`` fixed at 0).
        title: Figure super-title.

    Empty ``labels`` is logged and skipped (nothing to draw) instead of crashing
    on the grid-size computation.
    """
    log = logging.getLogger("parameter_inference_conditional")
    labels = labels.to(device)
    n_samples = int(labels.shape[0])
    if n_samples == 0:
        log.warning("No label vectors given; skipping sample grid %s", save_path)
        return

    if ema is not None:
        ema.apply_shadow()
    try:
        unet = model.unet
        unet.eval()
        with torch.no_grad():
            samples = diffusion.sample(
                model,
                labels=labels,
                channels=channels,
                height=height,
                width=width,
                device=device,
                progress=False,
                use_ddim=use_ddim,
                ddim_steps=ddim_steps,
                eta=0.0,
            )
    finally:
        # Always put the live (non-EMA) weights back, even if sampling failed.
        if ema is not None:
            ema.restore()

    n_cols = min(n_samples, 4)
    n_rows = (n_samples + n_cols - 1) // n_cols
    # squeeze=False always yields a 2-D array of axes, whatever the grid shape.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4.5 * n_cols, 4.5 * n_rows), squeeze=False
    )
    for i, ax in enumerate(axes.flat):
        if i < n_samples:
            img = samples[i, 0].cpu().numpy()
            label_str = ", ".join(f"{v:.3f}" for v in labels[i].cpu().tolist())
            ax.imshow(img, cmap="gray", vmin=-1, vmax=1)
            ax.set_title(label_str, fontsize=10)
        ax.axis("off")
    fig.suptitle(title, fontsize=14)
    fig.tight_layout()
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved sample grid to %s", save_path)
|
|
|
|
def _final_metrics_log(
    real_np: np.ndarray,
    gen_np: np.ndarray,
    log: logging.Logger,
) -> Dict[str, float]:
    """Score generated fields against real ones with two cheap summary statistics.

    Each project helper (``calculate_pdf_batch``,
    ``calculate_power_spectrum_batch``) returns a triple of which only the bin
    array and the mean curve are used. The score is the mean squared difference
    between the real and generated mean curves. The first power-spectrum bin is
    excluded.

    NOTE(review): this assumes both calls of each helper use the same binning;
    otherwise the curves are not comparable. Confirm in evaluate_conditional.

    Returns:
        ``pdf_mean_mse`` and ``pk_mean_mse``, plus the number of PDF bin centres
        (from the generated set) and of P(k) bins (from the real set), stored as
        floats.
    """
    _, pdf_mean_real, _ = calculate_pdf_batch(real_np)
    pdf_centres, pdf_mean_gen, _ = calculate_pdf_batch(gen_np)
    pdf_err = float(np.mean(np.square(pdf_mean_real - pdf_mean_gen)))

    pk_bins, pk_mean_real, _ = calculate_power_spectrum_batch(real_np)
    _, pk_mean_gen, _ = calculate_power_spectrum_batch(gen_np)
    skip_first = slice(1, None)  # drop the k=0 / lowest-k bin
    pk_err = float(np.mean(np.square(pk_mean_real[skip_first] - pk_mean_gen[skip_first])))

    log.info("Final metric | PDF mean MSE (density bins): %.6e", pdf_err)
    log.info("Final metric | P(k) mean MSE (k>0 bins): %.6e", pk_err)
    summary = {
        "pdf_mean_mse": pdf_err,
        "pk_mean_mse": pk_err,
        "pdf_bin_centers": float(pdf_centres.size),
        "pk_bins": float(pk_bins.size),
    }
    return summary
|
|
|
|
| |
|
|
# log(2*pi): constant term of the Gaussian log-density, hoisted out of the per-pixel NLL.
_LOG2PI = math.log(2.0 * math.pi)




def _gaussian_nll_spatial_sum(x: torch.Tensor, mean: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Negative log-likelihood of ``x`` under a diagonal Gaussian, summed per batch element.

    Args:
        x: Observations; dim 0 is the batch axis.
        mean: Gaussian means, broadcastable against ``x`` (same shape in practice).
        log_var: Log-variances. Trailing singleton axes are appended until it has
            as many dims as ``x``, so a per-sample ``(B,)`` tensor broadcasts over
            the remaining axes.

    Returns:
        Tensor of shape ``(B,)`` holding
        ``sum_pixels 0.5 * ((x - mean)^2 / var + log var + log 2pi)``.
    """
    while log_var.dim() < x.dim():
        log_var = log_var.unsqueeze(-1)
    inv = torch.exp(-log_var)
    nll_pix = 0.5 * ((x - mean) ** 2 * inv + log_var + _LOG2PI)
    # reshape (not view): the elementwise result can inherit a non-contiguous
    # memory format (e.g. channels_last), for which .view(B, -1) raises.
    return nll_pix.reshape(nll_pix.shape[0], -1).sum(dim=1)
|
|
|
|
@torch.no_grad()
def estimate_l0_nll_batch(
    model: ConditionalDiffusionModel,
    diffusion: GaussianDiffusion,
    x0: torch.Tensor,
    labels_norm: torch.Tensor,
    *,
    n_seeds: int,
    base_seed: int,
    share_noise_across_batch: bool = True,
) -> torch.Tensor:
    """
    Monte-Carlo average of L_0 = -log p_theta(x_0 | x_1, theta) with
    x_1 ~ q(x_1 | x_0) at diffusion index t=1 (lightly noised latent).

    For each seed ``base_seed + s`` (s = 0..n_seeds-1), a noised latent is drawn
    with ``diffusion.q_sample(x0, t=1)``. The reverse Gaussian is obtained from
    ``diffusion.p_mean_variance``, and the NLL of ``x0`` under it is summed over
    all non-batch axes. The per-row results are averaged over seeds.

    NOTE(review): whether index t=1 is the paper's x_1 depends on the indexing
    of GaussianDiffusion (0-based implementations reach x_1 at t=0). Confirm.

    Args:
        model: Conditional denoiser passed to ``p_mean_variance``; set to eval mode.
        diffusion: Provides ``timesteps``, ``q_sample`` and ``p_mean_variance``.
        x0: Clean fields; dim 0 is the batch axis (e.g. one field expanded over grid points).
        labels_norm: Normalised conditioning vectors, one row per batch element.
        n_seeds: Number of Monte-Carlo noise draws (>= 1).
        base_seed: Seed of the first draw. A local generator is used, so the
            global torch RNG state is left untouched.
        share_noise_across_batch: If True (default), one noise field per seed is
            broadcast over the whole batch (common random numbers). When all rows
            hold the same x_0 and differ only in theta, differences in the
            returned L_0 then reflect theta rather than different x_1 draws.
            If False, every row gets independent noise.

    Returns:
        Tensor of shape ``(B,)`` with the seed-averaged L_0 estimate per row.

    Raises:
        ValueError: If ``diffusion.timesteps < 3`` or ``n_seeds < 1``.
    """
    device = x0.device
    b = x0.shape[0]
    if diffusion.timesteps < 3:
        raise ValueError("VLB L0 requires diffusion.timesteps >= 3 (need t=1).")
    if n_seeds < 1:
        raise ValueError(f"n_seeds must be >= 1, got {n_seeds}.")
    t1 = torch.ones(b, device=device, dtype=torch.long)
    acc = torch.zeros(b, device=device)
    noise_shape = (1, *x0.shape[1:]) if share_noise_across_batch else tuple(x0.shape)
    # Local generator: reproducible draws without reseeding the global RNG as a side effect.
    gen = torch.Generator(device=device)
    model.eval()
    for s in range(n_seeds):
        gen.manual_seed(int(base_seed + s))
        noise = torch.randn(noise_shape, generator=gen, device=device, dtype=x0.dtype).expand_as(x0)
        x1 = diffusion.q_sample(x0, t1, noise=noise)
        mean, _pv, log_var, _ = diffusion.p_mean_variance(model, x1, t1, labels_norm, clip_denoised=True)
        acc += _gaussian_nll_spatial_sum(x0, mean, log_var)
    return acc / float(n_seeds)
|
|
|
|
def _build_theta_grid(
    theta_true: np.ndarray,
    half_width: float,
    prior_lo: np.ndarray,
    prior_hi: np.ndarray,
    n_per_dim: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build a 2-D evaluation grid in physical label units around ``theta_true``.

    For each of the two parameters, the axis runs from ``theta - half_width`` to
    ``theta + half_width``, clipped to ``[prior_lo, prior_hi]``, with
    ``n_per_dim`` evenly spaced float64 points.

    Returns:
        ``(points, axis0, axis1)`` where ``points`` has shape ``(n_per_dim**2, 2)``
        and is ordered with the axis-0 index varying slowest (``ij`` meshgrid,
        row-major ravel).
    """

    def _axis(dim: int) -> np.ndarray:
        lower = max(theta_true[dim] - half_width, prior_lo[dim])
        upper = min(theta_true[dim] + half_width, prior_hi[dim])
        return np.linspace(lower, upper, n_per_dim, dtype=np.float64)

    axis0, axis1 = _axis(0), _axis(1)
    mesh0, mesh1 = np.meshgrid(axis0, axis1, indexing="ij")
    points = np.column_stack((mesh0.ravel(), mesh1.ravel()))
    return points, axis0, axis1
|
|
|
|
def _delta_chi2_contour_levels_2d(masses: Optional[List[float]] = None) -> List[float]:
    """Delta chi^2 contour thresholds for a two-parameter likelihood surface.

    For 2 degrees of freedom the chi^2 CDF is ``1 - exp(-x / 2)``, so the
    threshold enclosing probability ``p`` is exactly ``-2 ln(1 - p)``.

    Args:
        masses: Enclosed probabilities, each strictly in (0, 1). If omitted, the
            historical fixed levels ``[2.30, 5.99, 11.82]`` are returned. These
            enclose ~68.3% (1 sigma), 95.0% and ~99.73% (3 sigma); the middle
            level is 95%, not 2 sigma (95.45% would be 6.18).

    Returns:
        Thresholds in the same order as ``masses``.

    Raises:
        ValueError: If any mass lies outside (0, 1).
    """
    if masses is None:
        return [2.30, 5.99, 11.82]
    levels: List[float] = []
    for p in masses:
        if not 0.0 < p < 1.0:
            raise ValueError(f"contour mass must lie in (0, 1), got {p!r}")
        levels.append(-2.0 * math.log1p(-p))
    return levels
|
|
|
|
def _shortest_mass_interval(x: np.ndarray, w: np.ndarray, mass: float = 0.68) -> Tuple[float, float]:
    """Return the narrowest ``[lo, hi]`` (grid values of ``x``) holding ``mass`` of the weights.

    ``w`` is normalised to sum to one after sorting by ``x``. For every start
    point, the first end point whose enclosed weight reaches ``mass`` (minus
    1e-9 tolerance) is found; the narrowest such span wins, with ties going to
    the leftmost. If no span reaches ``mass`` (e.g. all weights zero), the full
    range of ``x`` is returned.
    """
    perm = np.argsort(x)
    x_sorted = x[perm]
    probs = w[perm].astype(np.float64)
    probs = probs / (probs.sum() + 1e-30)
    cdf = np.concatenate([[0.0], np.cumsum(probs)])
    threshold = mass - 1e-9

    lo_best, hi_best = float(x_sorted[0]), float(x_sorted[-1])
    width_best = float("inf")
    for start in range(len(x_sorted)):
        # Indices (relative to start + 1) of end points that enclose enough mass.
        reached = np.flatnonzero(cdf[start + 1:] - cdf[start] >= threshold)
        if reached.size == 0:
            continue
        stop = start + 1 + int(reached[0])
        lo, hi = float(x_sorted[start]), float(x_sorted[stop - 1])
        if hi - lo < width_best:
            width_best = hi - lo
            lo_best, hi_best = lo, hi
    return lo_best, hi_best
|
|
|
|
def _vlb_posterior_summaries(
    L0: np.ndarray,
    g0: np.ndarray,
    g1: np.ndarray,
) -> Dict[str, Any]:
    """Turn a flat L_0 grid into posterior summaries.

    The grid likelihood is ``exp(-(L0 - min L0))``, i.e. ``exp(-Delta chi^2 / 2)``
    with ``Delta chi^2 = 2 (L0 - min L0)``. It is normalised, reshaped to
    ``(len(g0), len(g1))`` (axis-0 index varying slowest, as produced by
    ``_build_theta_grid``) and marginalised by summing over the other axis.

    Returns:
        JSON-serialisable dict with the Delta chi^2 grid, the grid MAP, and the
        shortest 68% interval of each marginal.
    """
    grid_shape = (len(g0), len(g1))
    delta_chi2 = 2.0 * (L0 - L0.min())
    log_like = -0.5 * delta_chi2
    like = np.exp(log_like - log_like.max())
    like = like / (like.sum() + 1e-30)
    like_2d = like.reshape(grid_shape)

    def _normalised(v: np.ndarray) -> np.ndarray:
        return v / (v.sum() + 1e-30)

    marg_omega = _normalised(like_2d.sum(axis=1))
    marg_sigma = _normalised(like_2d.sum(axis=0))
    i_map, j_map = np.unravel_index(int(np.argmax(like_2d)), grid_shape)
    omega_map, sigma_map = float(g0[i_map]), float(g1[j_map])
    interval_omega = _shortest_mass_interval(g0, marg_omega, 0.68)
    interval_sigma = _shortest_mass_interval(g1, marg_sigma, 0.68)
    return {
        "delta_chi2": delta_chi2.reshape(grid_shape).tolist(),
        "theta_map_omega_m": omega_map,
        "theta_map_sigma8": sigma_map,
        "marginal_68_omega_m": list(interval_omega),
        "marginal_68_sigma8": list(interval_sigma),
    }
|
|
|
|
def save_vlb_corner_figure(
    g0: np.ndarray,
    g1: np.ndarray,
    L0: np.ndarray,
    theta_true: np.ndarray,
    out_path: Path,
    *,
    names: Tuple[str, str] = (r"$\Omega_{\rm m}$", r"$\sigma_8$"),
) -> None:
    """Save a corner-style figure of the L_0 surface for one field.

    Joint panel (sigma_8 on x, Omega_m on y):
      * filled contours of ``Delta = 2 (L0 - min L0)``;
      * Delta chi^2 contour lines from ``_delta_chi2_contour_levels_2d``;
      * the true parameters as dashed lines and a cross.

    Top and right panels show marginals of the grid likelihood
    ``exp(-Delta / 2) = exp(-(L0 - min L0))``, summed over the other parameter
    and scaled to unit peak. This is the same likelihood that
    ``_vlb_posterior_summaries`` uses for its MAP and 68% intervals.

    Args:
        g0, g1: Omega_m and sigma_8 grid axes.
        L0: Flat L_0 values of length ``len(g0) * len(g1)``, Omega_m index varying slowest.
        theta_true: ``(Omega_m, sigma_8)`` used for the truth markers.
        out_path: Image destination; parent directories are created.
        names: Axis labels for (Omega_m, sigma_8).
    """
    from matplotlib.gridspec import GridSpec

    n0, n1 = len(g0), len(g1)
    excess = (L0 - L0.min()).reshape(n0, n1)
    D = 2.0 * excess
    G0_2d, G1_2d = np.meshgrid(g0, g1, indexing="ij")

    fig = plt.figure(figsize=(7.0, 6.8))
    gs = GridSpec(2, 2, figure=fig, width_ratios=[4, 1.1], height_ratios=[1, 4], wspace=0.12, hspace=0.12)
    ax_j = fig.add_subplot(gs[1, 0])
    ax_mx = fig.add_subplot(gs[0, 0], sharex=ax_j)
    ax_my = fig.add_subplot(gs[1, 1], sharey=ax_j)
    ax_mx.tick_params(labelleft=False, labelbottom=False)
    ax_my.tick_params(labelleft=False, labelbottom=False)

    cf = ax_j.contourf(G1_2d, G0_2d, D, levels=28, cmap="Greys", alpha=0.9)
    # Draw only thresholds the surface actually reaches (matplotlib warns on empty levels).
    levels = [lev for lev in _delta_chi2_contour_levels_2d() if lev < D.max()]
    if levels:
        ax_j.contour(G1_2d, G0_2d, D, levels=levels, colors="C0", linewidths=1.2)
    ax_j.axhline(theta_true[0], color="0.35", lw=0.8, ls="--")
    ax_j.axvline(theta_true[1], color="0.35", lw=0.8, ls="--")
    ax_j.scatter([theta_true[1]], [theta_true[0]], marker="x", s=80, c="crimson", zorder=9, linewidths=2)
    ax_j.set_xlabel(names[1])
    ax_j.set_ylabel(names[0])
    for lbl in ax_j.get_xticklabels():
        lbl.set_rotation(45)
        lbl.set_ha("right")
    fig.colorbar(cf, ax=ax_j, fraction=0.046, pad=0.02, label=r"$2\,[L_0 - \min L_0]$ (Mudur et al.\ proxy)")

    # Likelihood exp(-Delta/2) = exp(-(L0 - min L0)). Using exp(-0.5 (L0 - min L0))
    # here would be exp(-Delta/4) and would overstate the marginal widths.
    W = np.exp(-excess)
    m_omega = W.sum(axis=1)
    m_sigma = W.sum(axis=0)
    ax_mx.plot(g1, m_sigma / (m_sigma.max() + 1e-30), color="0.2", lw=1.5)
    ax_my.plot(m_omega / (m_omega.max() + 1e-30), g0, color="0.2", lw=1.5)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=180, bbox_inches="tight", facecolor="white")
    plt.close(fig)
|
|
|
|
def numpy_field_to_x0_tensor(img_01: np.ndarray, device: torch.device) -> torch.Tensor:
    """Rescale a field from [0, 1] to [-1, 1] and add a batch axis.

    Input ``[H, W]`` gets a channel axis; ``[C, H, W]`` is used as is. The result
    is a float32 ``[1, C, H, W]`` tensor on ``device``. The [0, 1] input range
    and the [-1, 1] training range are assumptions carried over from this
    function's original contract; they are not checked here.
    """
    field = torch.from_numpy(np.asarray(img_01, dtype=np.float32))
    chw = field[None] if field.dim() == 2 else field
    return chw.mul(2.0).sub(1.0)[None].to(device)
|
|
|
|
def run_vlb_parameter_inference(
    args: argparse.Namespace,
    log: logging.Logger,
    *,
    output_dir: Optional[Path] = None,
    checkpoint_path: Optional[str] = None,
    training_args_path: Optional[str] = None,
) -> None:
    """
    Mudur et al. (2023) style grid evaluation of L_0 on held-out fields.

    For ``min(args.vlb_n_fields, N)`` fields of ``args.vlb_split``, chosen at
    random with ``args.seed`` when fewer than N:
      1. Build a 2-D (Omega_m, sigma_8) grid around the field's true labels,
         clipped to the prior box.
      2. Estimate L_0 at every grid point in chunks of ``args.vlb_chunk_size``.
      3. Save a corner plot and record MAP / 68% marginal summaries.

    All summaries are written to ``vlb_inference_summary.json``.

    Args:
        args: Parsed CLI namespace (device, data_dir, seed and the ``vlb_*`` options).
        log: Logger for progress messages.
        output_dir: Overrides ``args.vlb_output_dir``.
        checkpoint_path: Overrides ``args.checkpoint``.
        training_args_path: Overrides ``args.training_args``.

    Raises:
        FileNotFoundError: If the training-args file or the checkpoint is missing.
        ValueError: If the split's labels are not of shape ``(N, 2)``; the grid
            only spans two parameters.
    """
    device = torch.device(args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu"))
    log.info("VLB inference | device=%s", device)

    ta = training_args_path or args.training_args
    ck = checkpoint_path or args.checkpoint
    if ta is None or not os.path.isfile(str(ta)):
        raise FileNotFoundError("VLB mode requires --training_args (args.json from training).")
    if ck is None or not os.path.isfile(str(ck)):
        raise FileNotFoundError("VLB mode requires --checkpoint.")

    config = load_training_config(str(ta))
    model = build_model(config, device)
    load_checkpoint(model, str(ck), device)
    diffusion = model.diffusion

    data_dir = Path(args.data_dir)
    label_mean, label_std = load_label_stats(data_dir)
    images, labels_phys = load_split(data_dir, args.vlb_split)
    labels_phys = np.asarray(labels_phys)
    if labels_phys.ndim != 2 or labels_phys.shape[1] != 2:
        raise ValueError(
            "VLB grid inference spans exactly two parameters (Omega_m, sigma_8); "
            f"got labels with shape {labels_phys.shape}."
        )
    n_fields = min(args.vlb_n_fields, len(images))
    rng = np.random.default_rng(args.seed)
    if n_fields < len(images):
        pick = rng.choice(len(images), size=n_fields, replace=False)
    else:
        pick = np.arange(n_fields)

    prior_lo = np.array([args.vlb_prior_omega_m[0], args.vlb_prior_sigma8[0]], dtype=np.float64)
    prior_hi = np.array([args.vlb_prior_omega_m[1], args.vlb_prior_sigma8[1]], dtype=np.float64)

    out_root = Path(output_dir or args.vlb_output_dir)
    out_root.mkdir(parents=True, exist_ok=True)
    all_rows: List[Dict[str, Any]] = []

    for k, idx in enumerate(pick):
        x0 = numpy_field_to_x0_tensor(images[idx], device)
        truth = labels_phys[idx].astype(np.float64)
        grid_pts, g0, g1 = _build_theta_grid(truth, args.vlb_half_width, prior_lo, prior_hi, args.vlb_n_grid)
        n_pts = grid_pts.shape[0]
        L0_accum = np.zeros(n_pts, dtype=np.float64)
        # One seed per field, shared by every chunk: grid points are compared under
        # the same noise draws (common random numbers) instead of chunk-dependent
        # ones, so Delta L_0 is not polluted by MC noise. Fields still get distinct seeds.
        field_seed = args.seed + k * 10007
        for start in range(0, n_pts, args.vlb_chunk_size):
            end = min(start + args.vlb_chunk_size, n_pts)
            chunk = grid_pts[start:end]
            lt = prepare_labels_for_model(chunk, label_mean, label_std).to(device)
            xrep = x0.expand(end - start, -1, -1, -1)
            L0_b = estimate_l0_nll_batch(
                model,
                diffusion,
                xrep,
                lt,
                n_seeds=args.vlb_l0_seeds,
                base_seed=field_seed,
            )
            L0_accum[start:end] = L0_b.detach().cpu().numpy()

        summ = _vlb_posterior_summaries(L0_accum, g0, g1)
        summ.update(
            {
                "field_index": int(idx),
                "theta_true_omega_m": float(truth[0]),
                "theta_true_sigma8": float(truth[1]),
            }
        )
        all_rows.append(summ)
        fig_path = out_root / f"vlb_corner_field_{k}_idx{idx}.png"
        save_vlb_corner_figure(g0, g1, L0_accum, truth, fig_path)
        log.info(
            "VLB field %d | MAP (Om,s8)=(%.4f,%.4f) true=(%.4f,%.4f) | 68%% marg Om %s s8 %s",
            k,
            summ["theta_map_omega_m"],
            summ["theta_map_sigma8"],
            truth[0],
            truth[1],
            summ["marginal_68_omega_m"],
            summ["marginal_68_sigma8"],
        )

    with open(out_root / "vlb_inference_summary.json", "w", encoding="utf-8") as f:
        json.dump(all_rows, f, indent=2)
    log.info("Wrote VLB summary to %s", out_root / "vlb_inference_summary.json")
|
|
|
|
def _save_loss_curve(
    epochs: List[int],
    losses_train: List[float],
    losses_val: List[float],
    path: Path,
) -> None:
    """Plot train/validation loss against epoch number (log-scale y axis) and save to ``path``."""
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(epochs, losses_train, label="Train")
    ax.plot(epochs, losses_val, label="Val")
    ax.set_yscale("log")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Training / validation noise-prediction loss")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.savefig(path, dpi=150)
    plt.close(fig)


def run_training(args: argparse.Namespace, log: logging.Logger) -> str:
    """Train the conditional DDPM, then sample, score and optionally run VLB inference.

    Steps:
      1. Create ``<args.output_dir>_<timestamp>/`` with ``checkpoints/`` and
         ``samples/``, add a ``training.log`` file handler and save the CLI args.
      2. Train with AdamW (weight decay 0.01) and cosine LR; optional AMP on CUDA.
         Validation runs with EMA weights. Early stopping on validation loss.
         Periodic preview grids; the loss curve is saved every 5 epochs and
         once more at the end.
      3. Reload ``best_model.pt`` (falling back to ``checkpoint_latest.pt``) into
         a fresh model and save a post-training sample grid. Compute PDF/P(k)
         metrics on random test fields; skipped with a warning if data files
         are missing.
      4. Write ``run_summary.txt`` and, with ``--run_vlb_after_train``, run
         ``run_vlb_parameter_inference`` on the reloaded checkpoint.

    Returns:
        The timestamped output directory.
    """
    device = torch.device(args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu"))
    log.info("Device: %s", device)

    use_amp = bool(args.use_amp) and device.type == "cuda"
    scaler = torch.amp.GradScaler("cuda") if use_amp else None
    if use_amp:
        log.info("Mixed precision (torch.amp.GradScaler + autocast in train_epoch) enabled.")

    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_dir = f"{args.output_dir}_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, "checkpoints"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "samples"), exist_ok=True)

    log_path = Path(output_dir) / "training.log"
    # Reconfigures the same named logger that `log` refers to, adding a file handler for this run.
    _setup_logging(log_path)

    save_training_args(args, output_dir)

    pin_memory = bool(args.pin_memory) and device.type == "cuda"
    log.info("Loading dataloaders from %s (pin_memory=%s)", args.data_dir, pin_memory)
    train_loader, val_loader, test_loader = get_conditional_dataloaders(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
        normalize_labels=args.normalize_labels,
    )
    # One fixed batch of test labels, reused for every preview grid.
    _, test_labels_tensor = next(iter(test_loader))
    h, w = _infer_spatial_size(train_loader)
    channels = train_loader.dataset[0][0].shape[0] if train_loader.dataset[0][0].dim() == 3 else 1
    log.info("Spatial size HxW=%dx%d, channels=%d", h, w, channels)

    log.info("Building ConditionalUNet + GaussianDiffusion (T=%d, schedule=%s)", args.timesteps, args.schedule_type)
    unet = ConditionalUNet(
        in_channels=channels,
        out_channels=channels,
        label_dim=args.label_dim,
        base_channels=args.base_channels,
        channel_multipliers=args.channel_multipliers,
        attention_levels=args.attention_levels,
        dropout=args.dropout,
    )
    diffusion = GaussianDiffusion(
        timesteps=args.timesteps,
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        schedule_type=args.schedule_type,
    )
    model = ConditionalDiffusionModel(unet, diffusion).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    log.info("Trainable parameters: %s", f"{n_params:,}")

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    ema = EMA(model, decay=args.ema_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    start_epoch = 0
    best_val_loss = float("inf")
    last_improvement_epoch = -1
    if args.resume:
        log.info("Resuming from %s", args.resume)
        checkpoint = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if "ema_shadow" in checkpoint:
            ema.shadow = checkpoint["ema_shadow"]
        if "scheduler_state_dict" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = int(checkpoint["epoch"]) + 1
        # NOTE(review): "loss" is the val loss passed to save_checkpoint for that
        # checkpoint's epoch; when resuming from checkpoint_latest.pt it may be
        # worse than the best so far. Confirm what save_checkpoint stores.
        best_val_loss = float(checkpoint.get("loss", float("inf")))
        last_improvement_epoch = int(checkpoint.get("last_improvement_epoch", -1))

    losses_train: list[float] = []
    losses_val: list[float] = []
    epochs_logged: list[int] = []
    loss_plot_path = Path(output_dir) / "losses.png"

    for epoch in range(start_epoch, args.epochs):
        train_loss = train_epoch(
            model, train_loader, optimizer, device, epoch, ema=ema, use_wandb=False, scaler=scaler
        )
        # Validate with EMA weights; always swap the live weights back, even if validate() raises.
        if ema is not None:
            ema.apply_shadow()
        try:
            val_loss = validate(model, val_loader, device)
        finally:
            if ema is not None:
                ema.restore()

        losses_train.append(train_loss)
        losses_val.append(val_loss)
        epochs_logged.append(epoch + 1)
        scheduler.step()

        log.info(
            "Epoch %d/%d | train_loss=%.6f | val_loss=%.6f | lr=%.6e",
            epoch + 1,
            args.epochs,
            train_loss,
            val_loss,
            optimizer.param_groups[0]["lr"],
        )

        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss
            last_improvement_epoch = epoch

        save_checkpoint(
            model,
            optimizer,
            ema,
            epoch,
            val_loss,
            os.path.join(output_dir, "checkpoints"),
            is_best=is_best,
            last_improvement_epoch=last_improvement_epoch,
            scheduler=scheduler,
        )

        if epoch - last_improvement_epoch >= args.early_stop_patience:
            log.info("Early stopping at epoch %d", epoch + 1)
            break

        # sample_every <= 0 disables periodic previews (the modulo would otherwise raise).
        if args.sample_every > 0 and (epoch + 1) % args.sample_every == 0:
            sample_path = Path(output_dir) / "samples" / f"samples_epoch_{epoch+1}.png"
            save_conditioned_sample_grid(
                model,
                diffusion,
                test_labels_tensor[: args.n_preview_samples],
                device,
                sample_path,
                channels=channels,
                height=h,
                width=w,
                ema=ema,
                use_ddim=args.use_ddim,
                ddim_steps=args.ddim_steps,
                title=f"Generated samples — epoch {epoch+1}",
            )

        if (epoch + 1) % 5 == 0:
            _save_loss_curve(epochs_logged, losses_train, losses_val, loss_plot_path)

    # Final curve, so epochs after the last multiple of 5 (or an early stop) are included.
    if losses_train:
        _save_loss_curve(epochs_logged, losses_train, losses_val, loss_plot_path)

    log.info("Training finished. Best validation loss: %.6f", best_val_loss)

    best_ckpt = Path(output_dir) / "checkpoints" / "best_model.pt"
    if not best_ckpt.is_file():
        best_ckpt = Path(output_dir) / "checkpoints" / "checkpoint_latest.pt"
    args_json = Path(output_dir) / "args.json"
    config = load_training_config(str(args_json))
    eval_model = build_model(config, device)
    load_checkpoint(eval_model, str(best_ckpt), device)
    eval_diffusion = eval_model.diffusion

    grid_path = Path(output_dir) / "generated_samples_conditional.png"
    save_conditioned_sample_grid(
        eval_model,
        eval_diffusion,
        test_labels_tensor[: args.n_preview_samples],
        device,
        grid_path,
        channels=channels,
        height=h,
        width=w,
        ema=None,
        use_ddim=args.use_ddim,
        ddim_steps=args.ddim_steps,
        title="Post-training conditional samples (EMA weights if present in checkpoint)",
    )

    data_dir = Path(args.data_dir)
    try:
        label_mean, label_std = load_label_stats(data_dir)
        images_test, labels_test = load_split(data_dir, "test")
        n_metric = min(args.metric_num_samples, len(images_test))
        if n_metric <= 0:
            log.warning("Skipping final PDF/P(k) metrics (no test samples requested or available).")
        else:
            idx = np.random.choice(len(images_test), n_metric, replace=False)
            real_slice = images_test[idx]
            labels_slice = labels_test[idx]
            labels_t = prepare_labels_for_model(labels_slice, label_mean, label_std).to(device)
            gen_list = []
            # At least 1 so a non-positive --metric_batch_size cannot yield range(..., step=0).
            bs = max(1, min(args.metric_batch_size, n_metric))
            for i in range(0, n_metric, bs):
                lt = labels_t[i : i + bs]
                with torch.no_grad():
                    g = eval_model.sample(
                        labels=lt,
                        channels=channels,
                        height=h,
                        width=w,
                        device=device,
                        progress=False,
                        use_ddim=args.use_ddim,
                        ddim_steps=args.ddim_steps,
                        eta=0.0,
                    )
                gen_list.append(from_model_output(g))
            gen_np = np.concatenate(gen_list, axis=0)
            metrics = _final_metrics_log(real_slice, gen_np, log)
            with open(Path(output_dir) / "final_metrics.json", "w", encoding="utf-8") as f:
                json.dump(
                    {
                        "best_val_loss": best_val_loss,
                        "checkpoint": str(best_ckpt),
                        **{k: v for k, v in metrics.items() if isinstance(v, (int, float))},
                    },
                    f,
                    indent=2,
                )
    except FileNotFoundError as e:
        log.warning("Skipping final PDF/P(k) metrics (data not found): %s", e)

    summary_path = Path(output_dir) / "run_summary.txt"
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write(f"output_dir: {output_dir}\n")
        f.write(f"best_val_loss: {best_val_loss}\n")
        f.write(f"best_checkpoint: {best_ckpt}\n")
        f.write(f"generated_grid: {grid_path}\n")
    log.info("Wrote run summary to %s", summary_path)

    if getattr(args, "run_vlb_after_train", False):
        vlb_dir = Path(output_dir) / args.vlb_output_subdir
        log.info("Running Mudur et al. VLB grid inference (post-train) -> %s", vlb_dir)
        run_vlb_parameter_inference(
            args,
            log,
            output_dir=vlb_dir,
            checkpoint_path=str(best_ckpt),
            training_args_path=str(args_json),
        )

    return output_dir
|
|
|
|
def _find_run_args(checkpoint: Path) -> Optional[str]:
    """Return the ``args.json``/``args.txt`` of the training run that owns ``checkpoint``, if any.

    ``run_training`` saves args into the run directory and weights into
    ``<run>/checkpoints/``. So the checkpoint's own directory is searched and,
    when that directory is named ``checkpoints``, its parent as well.
    """
    ckpt_dir = checkpoint.resolve().parent
    search_dirs = [ckpt_dir]
    if ckpt_dir.name == "checkpoints":
        search_dirs.append(ckpt_dir.parent)
    for run_dir in search_dirs:
        for name in ("args.json", "args.txt"):
            candidate = run_dir / name
            if candidate.is_file():
                return str(candidate)
    return None


def run_inference(args: argparse.Namespace, log: logging.Logger) -> None:
    """Render a preview grid of conditional samples from a trained checkpoint.

    Missing inputs are auto-resolved under the current directory:
      * checkpoint: newest ``checkpoints/best_model.pt``, else newest
        ``checkpoints/checkpoint_latest.pt``;
      * training args: the args file of the run that owns the checkpoint (so
        the architecture config matches the weights), else the newest
        ``args.json``/``args.txt`` anywhere.

    The grid is written to ``<args.inference_output>/generated_samples_conditional.png``
    and uses the first ``args.n_preview_samples`` labels of the first test batch.

    Raises:
        FileNotFoundError: If no checkpoint or no training-args file can be found.
    """
    device = torch.device(args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu"))
    log.info("Device: %s", device)

    # Resolve the checkpoint first so a missing --training_args can be taken from the same run.
    checkpoint_path = args.checkpoint
    if checkpoint_path is None or not os.path.isfile(checkpoint_path):
        ckpts = list(Path(".").rglob("checkpoints/best_model.pt"))
        if not ckpts:
            ckpts = list(Path(".").rglob("checkpoints/checkpoint_latest.pt"))
        if not ckpts:
            raise FileNotFoundError("Provide --checkpoint or train first (no best_model.pt found).")
        checkpoint_path = str(max(ckpts, key=lambda p: p.stat().st_mtime))
        log.info("Auto-selected checkpoint: %s", checkpoint_path)

    training_args_path = args.training_args
    if training_args_path is None or not os.path.isfile(training_args_path):
        training_args_path = _find_run_args(Path(checkpoint_path))
        if training_args_path is None:
            candidates = list(Path(".").rglob("args.json")) + list(Path(".").rglob("args.txt"))
            if not candidates:
                raise FileNotFoundError(
                    "Provide --training_args pointing to args.json (or args.txt) from a training run."
                )
            training_args_path = str(max(candidates, key=lambda p: p.stat().st_mtime))
        log.info("Auto-selected training args: %s", training_args_path)

    config = load_training_config(training_args_path)
    model = build_model(config, device)
    load_checkpoint(model, checkpoint_path, device)
    diffusion = model.diffusion

    pin_memory = bool(args.pin_memory) and device.type == "cuda"
    _, _, test_loader = get_conditional_dataloaders(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
        normalize_labels=config.get("normalize_labels", True),
    )
    _, labels_tensor = next(iter(test_loader))
    h, w = _infer_spatial_size(test_loader)
    ch = test_loader.dataset[0][0].shape[0] if test_loader.dataset[0][0].dim() == 3 else 1

    out_path = Path(args.inference_output) / "generated_samples_conditional.png"
    save_conditioned_sample_grid(
        model,
        diffusion,
        labels_tensor[: args.n_preview_samples],
        device,
        out_path,
        channels=ch,
        height=h,
        width=w,
        ema=None,
        use_ddim=args.use_ddim,
        ddim_steps=args.ddim_steps,
        title="Inference — conditional samples",
    )
    log.info("Inference complete. Grid: %s", out_path)
|
|
|
|
def build_argparser() -> argparse.ArgumentParser:
    """Return the command-line parser shared by the train, inference and vlb modes.

    Options are registered in a fixed order, which is also the ``--help`` order:
    run mode and device, UNet architecture, diffusion schedule, optimisation,
    data, training outputs, inference inputs, and the Mudur et al. (2023) VLB
    grid settings.
    """
    parser = argparse.ArgumentParser(
        description="Conditional DDPM: train, sample, and VLB-based cosmo inference (Mudur et al. 2023)"
    )
    add = parser.add_argument
    on_off = argparse.BooleanOptionalAction

    # Run mode / device.
    add(
        "--mode",
        type=str,
        choices=["train", "inference", "vlb"],
        required=True,
        help="train | inference (samples) | vlb (L0 grid on held-out fields, arXiv:2312.07534)",
    )
    add("--device", type=str, default="", help="cuda | cpu (empty = auto)")
    add("--seed", type=int, default=42)

    # UNet architecture.
    add("--label_dim", type=int, default=2, help="Conditioning vector dimension (e.g. Omega_m, sigma_8).")
    add("--base_channels", type=int, default=64)
    add("--channel_multipliers", type=int, nargs="+", default=[1, 2, 4, 8])
    add("--attention_levels", type=int, nargs="+", default=[2, 3])
    add("--dropout", type=float, default=0.1)

    # Diffusion schedule.
    add("--timesteps", type=int, default=1500, help="Forward process length T (beta schedule discretization).")
    for flag, value in (("--beta_start", 1e-4), ("--beta_end", 0.02)):
        add(flag, type=float, default=value)
    add("--schedule_type", type=str, default="linear", choices=["linear", "cosine"])

    # Optimisation.
    for flag, kind, value in (
        ("--epochs", int, 100),
        ("--batch_size", int, 8),
        ("--lr", float, 2e-4),
        ("--ema_decay", float, 0.9999),
        ("--num_workers", int, 4),
        ("--early_stop_patience", int, 30),
    ):
        add(flag, type=kind, default=value)
    add("--use_amp", action="store_true", default=False)
    add("--pin_memory", action=on_off, default=True)

    # Data.
    add("--data_dir", type=str, default="./data/params_2")
    add("--normalize_labels", action=on_off, default=True)

    # Training outputs and previews.
    add("--output_dir", type=str, default="outputs_conditional")
    add("--resume", type=str, default="")
    add("--sample_every", type=int, default=10)
    add("--use_ddim", action=on_off, default=True)
    add("--ddim_steps", type=int, default=50)
    add("--n_preview_samples", type=int, default=8, help="Grid size for conditional previews.")
    add("--metric_num_samples", type=int, default=64, help="Samples for post-train PDF/P(k) metrics.")
    add("--metric_batch_size", type=int, default=8)

    # Inference inputs.
    add("--checkpoint", type=str, default=None)
    add("--training_args", type=str, default=None, help="Path to args.json or args.txt from a train run.")
    add("--inference_output", type=str, default="inference_outputs", help="Directory for inference artifacts.")

    # VLB grid inference (Mudur et al. 2023).
    add(
        "--run_vlb_after_train",
        action="store_true",
        help="After training, run L0 grid parameter inference on held-out fields (writes under vlb_output_subdir).",
    )
    add("--vlb_output_subdir", type=str, default="vlb_posterior", help="Subfolder under training output_dir for VLB plots.")
    add("--vlb_output_dir", type=str, default="vlb_inference_out", help="Output directory when --mode vlb.")
    add("--vlb_split", type=str, default="test", choices=["train", "val", "test"])
    add("--vlb_n_fields", type=int, default=4, help="Number of random fields to evaluate.")
    add("--vlb_n_grid", type=int, default=32, help="Grid points per parameter (paper uses 50; smaller is faster).")
    add(
        "--vlb_half_width",
        type=float,
        default=0.1,
        help="Half-width of grid in each physical parameter (paper: ±0.1 clipped to CAMELS priors).",
    )
    add(
        "--vlb_prior_omega_m",
        type=float,
        nargs=2,
        default=[0.1, 0.5],
        metavar=("LO", "HI"),
        help="Prior range for Omega_m (physical units, matches Mudur et al. CMD priors).",
    )
    add(
        "--vlb_prior_sigma8",
        type=float,
        nargs=2,
        default=[0.6, 1.0],
        metavar=("LO", "HI"),
        help="Prior range for sigma_8 (physical units).",
    )
    add("--vlb_l0_seeds", type=int, default=3, help="MC seeds for x1 ~ q(x1|x0) in L0 (cosmic variance proxy).")
    add("--vlb_chunk_size", type=int, default=32, help="Batch size for grid points on GPU.")

    return parser
|
|
|
|
def main() -> None:
    """Command-line entry point.

    Parses arguments, seeds all RNGs, configures logging and dispatches on
    ``--mode``: ``train`` runs ``run_training``. ``inference`` and ``vlb`` first
    create their output directory, then run ``run_inference`` or
    ``run_vlb_parameter_inference`` respectively.
    """
    args = build_argparser().parse_args()
    set_seed(args.seed)
    log = _setup_logging()
    log.info("parameter_inference_conditional.py | mode=%s", args.mode)

    if args.mode == "train":
        run_training(args, log)
        return

    sampling_only = args.mode == "inference"
    os.makedirs(args.inference_output if sampling_only else args.vlb_output_dir, exist_ok=True)
    runner = run_inference if sampling_only else run_vlb_parameter_inference
    runner(args, log)


if __name__ == "__main__":
    main()
|
|