# nsgf-plusplus / trainer.py
# Last commit 8b62ba9 (rogermt): Add checkpointing, resume, CIFAR OOM fix, --sinkhorn-batch flag
"""trainer.py — Training procedures for NSGF and NSGF++.
Implements:
1. Trajectory pool construction (Phase 1: Sinkhorn gradient flow)
2. NSGF velocity field matching training
3. NSF (Neural Straight Flow) training for NSGF++
4. Phase-transition time predictor training
5. End-to-end NSGF++ training pipeline
6. Checkpointing and resume support
Reference: arXiv:2401.14069, Section 4.2–4.4, Appendix D, E
"""
import os
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Optional, Dict, Any, Tuple
from dataset_loader import DatasetLoader
from sinkhorn_flow import (
SinkhornPotentialComputer, SinkhornGradientFlow, TrajectoryPool,
)
from model import VelocityMLP, VelocityUNet, PhaseTransitionPredictor
logger = logging.getLogger(__name__)
def _save_checkpoint(path: str, **kwargs):
    """Save a checkpoint dict to disk without clobbering the previous one.

    The keyword arguments are stored as a single dict via ``torch.save``.
    The payload is first written to ``<path>.tmp`` and only then moved over
    ``path`` with ``os.replace``, so a crash or a full disk in the middle of
    a save cannot leave a truncated file where the last good checkpoint was.

    Args:
        path: Destination file. Its parent directory is created if missing;
            a bare filename is saved into the current directory.
        **kwargs: Entries of the checkpoint dict.
    """
    directory = os.path.dirname(path) or "."
    os.makedirs(directory, exist_ok=True)

    tmp_path = f"{path}.tmp"
    try:
        torch.save(kwargs, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # The temp file only still exists when the save or the rename failed;
        # remove the partial write so it does not accumulate on disk.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    logger.info(f"Checkpoint saved: {path}")
class NSGFTrainer:
    """Trainer for the Neural Sinkhorn Gradient Flow model.

    Loss (Eq. 14): L(θ) = E_{(x,v,t) ~ pool} ||v_θ(x, t) - v̂(x)||²

    Usage is two-step: ``build_trajectory_pool()`` fills ``self.pool`` with
    trajectories produced by the Sinkhorn gradient flow, then ``train()``
    regresses the model onto the velocities sampled from that pool.
    """

    def __init__(self, model: nn.Module, data_loader: DatasetLoader,
                 config: dict, device: str = "cpu",
                 checkpoint_dir: str = "checkpoints"):
        """Set up the flow simulator, the trajectory pool and the optimizer.

        Args:
            model: Velocity network, called as ``model(x, t)``.
            data_loader: Object providing ``sample_source`` and
                ``sample_target``.
            config: Nested settings dict. Sections read here: ``sinkhorn``
                (``blur``, ``scaling``, ``eta``, ``num_steps``), ``pool``
                (``max_size``), ``training`` (falling back to
                ``nsgf_training``) and the top-level ``checkpoint_every``.
            device: Device the model is moved to and batches are sampled on.
            checkpoint_dir: Directory that receives ``nsgf_checkpoint.pt``.
        """
        self.model = model.to(device)
        self.data_loader = data_loader
        self.config = config
        self.device = device
        self.checkpoint_dir = checkpoint_dir

        sink_cfg = config.get("sinkhorn", {})
        self.potential_computer = SinkhornPotentialComputer(
            blur=sink_cfg.get("blur", 0.5),
            scaling=sink_cfg.get("scaling", 0.80),
        )
        self.gradient_flow = SinkhornGradientFlow(
            potential_computer=self.potential_computer,
            eta=sink_cfg.get("eta", 1.0),
            num_steps=sink_cfg.get("num_steps", 5),
        )

        # Pool capacity used to be fixed at 5M entries; it is now an optional
        # ``pool.max_size`` setting with the same default.
        pool_cfg = config.get("pool", {})
        self.pool = TrajectoryPool(max_size=pool_cfg.get("max_size", 5_000_000))

        train_cfg = config.get("training", config.get("nsgf_training", {}))
        self.num_iterations = train_cfg.get("num_iterations", 20000)
        self.train_batch_size = train_cfg.get("batch_size", 256)
        self.lr = train_cfg.get("learning_rate", 1e-3)
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.lr,
            betas=(train_cfg.get("beta1", 0.9), train_cfg.get("beta2", 0.999)),
            weight_decay=train_cfg.get("weight_decay", 0.0),
        )
        self.checkpoint_every = config.get("checkpoint_every", 5000)

    def _should_checkpoint(self, completed_steps: int) -> bool:
        """Return True when a periodic checkpoint is due.

        A missing, zero or negative ``checkpoint_every`` disables periodic
        checkpoints instead of failing in the modulo below.
        """
        every = self.checkpoint_every
        if not every or every <= 0:
            return False
        return completed_steps % every == 0

    def build_trajectory_pool(self, num_batches: Optional[int] = None):
        """Run the Sinkhorn flow on fresh batches and store the trajectories.

        Args:
            num_batches: Number of source/target batches to simulate. When
                ``None`` the value of ``config["pool"]["num_batches"]`` is
                used (default 200).
        """
        if num_batches is None:
            num_batches = self.config.get("pool", {}).get("num_batches", 200)
        sink_batch_size = self.config.get("sinkhorn", {}).get("batch_size", 256)
        logger.info(
            f"Building trajectory pool: {num_batches} batches × "
            f"{sink_batch_size} samples × {self.gradient_flow.num_steps} steps"
        )

        # Report progress roughly ten times over the whole build.
        report_every = max(1, num_batches // 10)
        for batch_idx in range(num_batches):
            X0 = self.data_loader.sample_source(sink_batch_size, self.device)
            Y = self.data_loader.sample_target(sink_batch_size, self.device)
            _, trajectory = self.gradient_flow.run_flow(X0, Y, store_trajectory=True)
            self.pool.add_trajectory(trajectory)
            if (batch_idx + 1) % report_every == 0:
                logger.info(f" Pool building: {batch_idx + 1}/{num_batches}, pool size: {len(self.pool)}")
        logger.info(f"Trajectory pool built. Total entries: {len(self.pool)}")

        # Free GPU memory used during Sinkhorn computation
        if self.device != "cpu":
            torch.cuda.empty_cache()

        # Pre-concatenate for O(1) sampling during training
        self.pool.finalize()
        logger.info("Trajectory pool finalized (pre-concatenated for fast sampling).")

    def train(self, start_step: int = 0) -> Dict[str, list]:
        """Fit the velocity model to samples drawn from the trajectory pool.

        Args:
            start_step: First iteration index to run. Only the loop counter
                is offset; model and optimizer state are not restored here.

        Returns:
            Dict with parallel lists ``"loss"`` and ``"step"``, recorded on
            the first executed step and every 500 steps.
        """
        self.model.train()
        history = {"loss": [], "step": []}
        logger.info(f"Starting NSGF velocity field matching: {self.num_iterations} iterations (from step {start_step})")

        # Loop invariants: the time normaliser and the checkpoint location.
        time_scale = max(self.gradient_flow.num_steps, 1.0)
        checkpoint_path = os.path.join(self.checkpoint_dir, "nsgf_checkpoint.pt")

        for step in range(start_step, self.num_iterations):
            x_batch, v_batch, t_batch = self.pool.sample(self.train_batch_size, self.device)
            t_normalized = t_batch / time_scale
            v_pred = self.model(x_batch, t_normalized)
            loss = ((v_pred - v_batch) ** 2).mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            completed = step + 1
            if completed % 500 == 0 or step == start_step:
                # .item() is only called on logging steps.
                loss_val = loss.item()
                history["loss"].append(loss_val)
                history["step"].append(completed)
                logger.info(f" Step {completed}/{self.num_iterations}, Loss: {loss_val:.6f}")

            if self._should_checkpoint(completed):
                _save_checkpoint(
                    checkpoint_path,
                    model_state=self.model.state_dict(),
                    optimizer_state=self.optimizer.state_dict(),
                    step=completed,
                    history=history,
                )

        logger.info("NSGF training complete.")
        return history
class NSFTrainer:
    """Trainer for Neural Straight Flow (Phase 2 of NSGF++).

    Straight flow: X_t = (1-t)*P_0 + t*P_1, target velocity = P_1 - P_0

    P_0 is obtained by pushing source samples through the (frozen) NSGF
    model with explicit Euler steps; P_1 is drawn from the target.
    """

    def __init__(self, model: nn.Module, nsgf_model: nn.Module,
                 data_loader: DatasetLoader, config: dict,
                 nsgf_num_steps: int = 5, device: str = "cpu",
                 checkpoint_dir: str = "checkpoints"):
        """Store collaborators and build the optimizer for ``model``.

        Args:
            model: Straight-flow velocity network being trained, called as
                ``model(x, t)``.
            nsgf_model: NSGF velocity network; put in eval mode and only
                used to generate P_0.
            data_loader: Object providing ``sample_source`` and
                ``sample_target``.
            config: Settings dict; reads ``nsf_training`` (falling back to
                ``training``) and the top-level ``checkpoint_every``.
            nsgf_num_steps: Number of Euler steps used to generate P_0.
            device: Device both models are moved to.
            checkpoint_dir: Directory that receives ``nsf_checkpoint.pt``.
        """
        self.model = model.to(device)
        self.nsgf_model = nsgf_model.to(device)
        self.nsgf_model.eval()
        self.data_loader = data_loader
        self.config = config
        self.device = device
        self.nsgf_num_steps = nsgf_num_steps
        self.checkpoint_dir = checkpoint_dir

        train_cfg = config.get("nsf_training", config.get("training", {}))
        self.num_iterations = train_cfg.get("num_iterations", 100000)
        self.train_batch_size = train_cfg.get("batch_size", 128)
        self.lr = train_cfg.get("learning_rate", 1e-4)
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.lr,
            betas=(train_cfg.get("beta1", 0.9), train_cfg.get("beta2", 0.999)),
            weight_decay=train_cfg.get("weight_decay", 0.0),
        )
        self.checkpoint_every = config.get("checkpoint_every", 5000)

    @staticmethod
    def _expand_time(t: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """Reshape per-sample times ``t`` so they broadcast against ``like``.

        Gives ``t`` one trailing singleton axis per non-batch axis of
        ``like``. For 2-D and 4-D batches this is the same shape the
        previous hard-coded branches produced; other ranks (e.g. 3-D) now
        broadcast per sample as well instead of being forced to 4-D.
        """
        return t.view(-1, *([1] * (like.dim() - 1)))

    def _should_checkpoint(self, completed_steps: int) -> bool:
        """Return True when a periodic checkpoint is due.

        A missing, zero or negative ``checkpoint_every`` disables periodic
        checkpoints instead of failing in the modulo below.
        """
        every = self.checkpoint_every
        if not every or every <= 0:
            return False
        return completed_steps % every == 0

    @torch.no_grad()
    def _generate_nsgf_samples(self, n: int) -> torch.Tensor:
        """Push ``n`` source samples through the NSGF model with Euler steps.

        Uses ``nsgf_num_steps`` steps of size ``1 / nsgf_num_steps``; the
        model is queried at times ``0, dt, 2*dt, ...``.
        """
        X = self.data_loader.sample_source(n, self.device)
        dt = 1.0 / self.nsgf_num_steps
        for step in range(self.nsgf_num_steps):
            t = torch.full((n,), step * dt, device=self.device)
            v = self.nsgf_model(X, t)
            X = X + dt * v
        return X

    def train(self, start_step: int = 0) -> Dict[str, list]:
        """Train the straight-flow model by velocity regression.

        Args:
            start_step: First iteration index to run. Only the loop counter
                is offset; model and optimizer state are not restored here.

        Returns:
            Dict with parallel lists ``"loss"`` and ``"step"``, recorded on
            the first executed step and every 500 steps.
        """
        self.model.train()
        history = {"loss": [], "step": []}
        logger.info(f"Starting NSF training: {self.num_iterations} iterations (from step {start_step})")

        checkpoint_path = os.path.join(self.checkpoint_dir, "nsf_checkpoint.pt")

        for step in range(start_step, self.num_iterations):
            P0 = self._generate_nsgf_samples(self.train_batch_size)
            P1 = self.data_loader.sample_target(self.train_batch_size, self.device)
            t = torch.rand(self.train_batch_size, device=self.device)

            # Linear interpolation between the endpoints; the regression
            # target is the constant displacement P1 - P0.
            t_expand = self._expand_time(t, P0)
            X_t = (1 - t_expand) * P0 + t_expand * P1
            v_target = P1 - P0
            v_pred = self.model(X_t, t)
            loss = ((v_pred - v_target) ** 2).mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            completed = step + 1
            if completed % 500 == 0 or step == start_step:
                loss_val = loss.item()
                history["loss"].append(loss_val)
                history["step"].append(completed)
                logger.info(f" Step {completed}/{self.num_iterations}, Loss: {loss_val:.6f}")

            if self._should_checkpoint(completed):
                _save_checkpoint(
                    checkpoint_path,
                    model_state=self.model.state_dict(),
                    optimizer_state=self.optimizer.state_dict(),
                    step=completed,
                    history=history,
                )

        logger.info("NSF training complete.")
        return history
class PhaseTransitionTrainer:
    """Trainer for the phase-transition time predictor.

    Loss: L(ϕ) = E_{t~U(0,1)} ||t - t_ϕ(X_t)||²

    X_t is the straight-line interpolation between NSGF-generated samples
    P_0 and target samples P_1 at a uniformly drawn time t.
    """

    def __init__(self, predictor: PhaseTransitionPredictor, nsgf_model: nn.Module,
                 data_loader: DatasetLoader, config: dict,
                 nsgf_num_steps: int = 5, device: str = "cpu",
                 checkpoint_dir: str = "checkpoints"):
        """Store collaborators and build the optimizer for ``predictor``.

        Args:
            predictor: Network mapping ``X_t`` to a predicted time.
            nsgf_model: NSGF velocity network; put in eval mode and only
                used to generate P_0.
            data_loader: Object providing ``sample_source`` and
                ``sample_target``.
            config: Settings dict; reads ``time_predictor``
                (``num_iterations``, ``batch_size``, ``learning_rate``) and
                the top-level ``checkpoint_every``.
            nsgf_num_steps: Number of Euler steps used to generate P_0.
            device: Device both networks are moved to.
            checkpoint_dir: Directory that receives
                ``predictor_checkpoint.pt``.
        """
        self.predictor = predictor.to(device)
        self.nsgf_model = nsgf_model.to(device)
        self.nsgf_model.eval()
        self.data_loader = data_loader
        self.config = config
        self.device = device
        self.nsgf_num_steps = nsgf_num_steps
        self.checkpoint_dir = checkpoint_dir

        tp_cfg = config.get("time_predictor", {})
        self.num_iterations = tp_cfg.get("num_iterations", 40000)
        self.batch_size = tp_cfg.get("batch_size", 128)
        self.lr = tp_cfg.get("learning_rate", 1e-4)
        self.optimizer = optim.Adam(self.predictor.parameters(), lr=self.lr, betas=(0.9, 0.999))
        self.checkpoint_every = config.get("checkpoint_every", 5000)

    @staticmethod
    def _expand_time(t: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """Reshape per-sample times ``t`` so they broadcast against ``like``.

        Gives ``t`` one trailing singleton axis per non-batch axis of
        ``like``. For 2-D and 4-D batches this is the same shape the
        previous hard-coded branches produced; other ranks (e.g. 3-D) now
        broadcast per sample as well.
        """
        return t.view(-1, *([1] * (like.dim() - 1)))

    @staticmethod
    def _time_loss(t_pred: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Mean squared error between predicted and true times.

        The prediction is reshaped to ``t``'s shape first. The predictor's
        output shape is not visible from this file: if it is already
        ``(batch,)`` the reshape is a no-op, and if it is ``(batch, 1)``
        this prevents ``t_pred - t`` from silently broadcasting to
        ``(batch, batch)``. A prediction with a different number of
        elements raises instead of broadcasting.
        """
        return ((t_pred.reshape(t.shape) - t) ** 2).mean()

    def _should_checkpoint(self, completed_steps: int) -> bool:
        """Return True when a periodic checkpoint is due.

        A missing, zero or negative ``checkpoint_every`` disables periodic
        checkpoints instead of failing in the modulo below.
        """
        every = self.checkpoint_every
        if not every or every <= 0:
            return False
        return completed_steps % every == 0

    @torch.no_grad()
    def _generate_nsgf_samples(self, n: int) -> torch.Tensor:
        """Push ``n`` source samples through the NSGF model with Euler steps.

        Uses ``nsgf_num_steps`` steps of size ``1 / nsgf_num_steps``; the
        model is queried at times ``0, dt, 2*dt, ...``.
        """
        X = self.data_loader.sample_source(n, self.device)
        dt = 1.0 / self.nsgf_num_steps
        for step in range(self.nsgf_num_steps):
            t = torch.full((n,), step * dt, device=self.device)
            v = self.nsgf_model(X, t)
            X = X + dt * v
        return X

    def train(self, start_step: int = 0) -> Dict[str, list]:
        """Train the predictor to recover the interpolation time of ``X_t``.

        Args:
            start_step: First iteration index to run. Only the loop counter
                is offset; predictor and optimizer state are not restored
                here.

        Returns:
            Dict with parallel lists ``"loss"`` and ``"step"``, recorded on
            the first executed step and every 1000 steps.
        """
        self.predictor.train()
        history = {"loss": [], "step": []}
        logger.info(f"Starting phase predictor training: {self.num_iterations} iterations (from step {start_step})")

        checkpoint_path = os.path.join(self.checkpoint_dir, "predictor_checkpoint.pt")

        for step in range(start_step, self.num_iterations):
            P0 = self._generate_nsgf_samples(self.batch_size)
            P1 = self.data_loader.sample_target(self.batch_size, self.device)
            t = torch.rand(self.batch_size, device=self.device)

            t_expand = self._expand_time(t, P0)
            X_t = (1 - t_expand) * P0 + t_expand * P1
            t_pred = self.predictor(X_t)
            loss = self._time_loss(t_pred, t)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            completed = step + 1
            if completed % 1000 == 0 or step == start_step:
                loss_val = loss.item()
                history["loss"].append(loss_val)
                history["step"].append(completed)
                logger.info(f" Step {completed}/{self.num_iterations}, Loss: {loss_val:.6f}")

            if self._should_checkpoint(completed):
                _save_checkpoint(
                    checkpoint_path,
                    model_state=self.predictor.state_dict(),
                    optimizer_state=self.optimizer.state_dict(),
                    step=completed,
                    history=history,
                )

        logger.info("Phase predictor training complete.")
        return history
class NSGFPlusPlusTrainer:
    """End-to-end NSGF++ trainer (Algorithm 3 / Appendix D).

    Runs the three phases in order (NSGF, NSF, phase-transition predictor)
    and writes ``phase{1,2,3}_complete.pt`` into ``checkpoint_dir`` after
    each finished phase so training can be resumed.

    NOTE: this class only *writes* the per-phase checkpoints. When a phase
    is skipped via ``resume_phase`` the models are used exactly as they
    were passed in; restoring their weights is left to the caller.
    """

    def __init__(self, nsgf_model: nn.Module, nsf_model: nn.Module,
                 phase_predictor: PhaseTransitionPredictor,
                 data_loader: DatasetLoader, config: dict, device: str = "cpu",
                 checkpoint_dir: str = "checkpoints"):
        """Keep references to the three networks and shared settings.

        Nothing is moved to ``device`` here; the per-phase trainers do that
        when they are constructed.
        """
        self.nsgf_model = nsgf_model
        self.nsf_model = nsf_model
        self.phase_predictor = phase_predictor
        self.data_loader = data_loader
        self.config = config
        self.device = device
        self.checkpoint_dir = checkpoint_dir

    def _release_cached_memory(self):
        """Hand cached CUDA blocks back to the driver between phases."""
        if self.device != "cpu":
            torch.cuda.empty_cache()

    def train_all(self, resume_phase: int = 1) -> Dict[str, Any]:
        """Train all phases. resume_phase: 1=start from NSGF, 2=skip to NSF, 3=skip to predictor.

        Args:
            resume_phase: First phase to execute; phases with a smaller
                number are skipped (and only logged).

        Returns:
            Dict mapping ``"nsgf"``, ``"nsf"`` and ``"phase_predictor"`` to
            the history returned by the corresponding trainer. Skipped
            phases have no entry.
        """
        results = {}
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        banner = "=" * 60
        # Phases 2 and 3 both integrate the NSGF model with this many steps.
        nsgf_steps = self.config.get("sinkhorn", {}).get("num_steps", 5)

        if resume_phase <= 1:
            logger.info(banner)
            logger.info("Phase 1: Training NSGF model")
            logger.info(banner)
            nsgf_trainer = NSGFTrainer(
                model=self.nsgf_model, data_loader=self.data_loader,
                config=self.config, device=self.device,
                checkpoint_dir=self.checkpoint_dir,
            )
            nsgf_trainer.build_trajectory_pool()
            results["nsgf"] = nsgf_trainer.train()
            _save_checkpoint(
                os.path.join(self.checkpoint_dir, "phase1_complete.pt"),
                nsgf_model_state=self.nsgf_model.state_dict(),
                phase=1,
            )
            # Drop the whole trainer, not just its pool: its optimizer state
            # and flow objects are not needed by the later phases and would
            # otherwise stay referenced until this method returns.
            del nsgf_trainer
            self._release_cached_memory()
        else:
            logger.info(f"Skipping Phase 1 (resuming from phase {resume_phase})")

        if resume_phase <= 2:
            logger.info(banner)
            logger.info("Phase 2: Training NSF (Neural Straight Flow) model")
            logger.info(banner)
            nsf_trainer = NSFTrainer(
                model=self.nsf_model, nsgf_model=self.nsgf_model,
                data_loader=self.data_loader, config=self.config,
                nsgf_num_steps=nsgf_steps, device=self.device,
                checkpoint_dir=self.checkpoint_dir,
            )
            results["nsf"] = nsf_trainer.train()
            _save_checkpoint(
                os.path.join(self.checkpoint_dir, "phase2_complete.pt"),
                nsgf_model_state=self.nsgf_model.state_dict(),
                nsf_model_state=self.nsf_model.state_dict(),
                phase=2,
            )
            # Same reasoning as after phase 1: free the NSF optimizer state
            # before the predictor is trained.
            del nsf_trainer
            self._release_cached_memory()
        else:
            logger.info(f"Skipping Phase 2 (resuming from phase {resume_phase})")

        if resume_phase <= 3:
            logger.info(banner)
            logger.info("Phase 3: Training phase-transition time predictor")
            logger.info(banner)
            pt_trainer = PhaseTransitionTrainer(
                predictor=self.phase_predictor, nsgf_model=self.nsgf_model,
                data_loader=self.data_loader, config=self.config,
                nsgf_num_steps=nsgf_steps, device=self.device,
                checkpoint_dir=self.checkpoint_dir,
            )
            results["phase_predictor"] = pt_trainer.train()
            _save_checkpoint(
                os.path.join(self.checkpoint_dir, "phase3_complete.pt"),
                nsgf_model_state=self.nsgf_model.state_dict(),
                nsf_model_state=self.nsf_model.state_dict(),
                predictor_state=self.phase_predictor.state_dict(),
                phase=3,
            )

        logger.info(banner)
        logger.info("NSGF++ training complete!")
        logger.info(banner)
        return results