"""trainer.py — Training procedures for NSGF and NSGF++. Implements: 1. Trajectory pool construction (Phase 1: Sinkhorn gradient flow) 2. NSGF velocity field matching training 3. NSF (Neural Straight Flow) training for NSGF++ 4. Phase-transition time predictor training 5. End-to-end NSGF++ training pipeline 6. Checkpointing and resume support Reference: arXiv:2401.14069, Section 4.2–4.4, Appendix D, E """ import os import logging import torch import torch.nn as nn import torch.optim as optim from typing import Optional, Dict, Any, Tuple from dataset_loader import DatasetLoader from sinkhorn_flow import ( SinkhornPotentialComputer, SinkhornGradientFlow, TrajectoryPool, ) from model import VelocityMLP, VelocityUNet, PhaseTransitionPredictor logger = logging.getLogger(__name__) def _save_checkpoint(path: str, **kwargs): """Save a checkpoint dict to disk.""" os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True) torch.save(kwargs, path) logger.info(f"Checkpoint saved: {path}") class NSGFTrainer: """Trainer for the Neural Sinkhorn Gradient Flow model. Loss (Eq. 14): L(θ) = E_{(x,v,t) ~ pool} ||v_θ(x, t) - v̂(x)||² """ def __init__(self, model: nn.Module, data_loader: DatasetLoader, config: dict, device: str = "cpu", checkpoint_dir: str = "checkpoints"): self.model = model.to(device) self.data_loader = data_loader self.config = config self.device = device self.checkpoint_dir = checkpoint_dir sink_cfg = config.get("sinkhorn", {}) self.potential_computer = SinkhornPotentialComputer( blur=sink_cfg.get("blur", 0.5), scaling=sink_cfg.get("scaling", 0.80), ) self.gradient_flow = SinkhornGradientFlow( potential_computer=self.potential_computer, eta=sink_cfg.get("eta", 1.0), num_steps=sink_cfg.get("num_steps", 5), ) self.pool = TrajectoryPool(max_size=5_000_000) train_cfg = config.get("training", config.get("nsgf_training", {})) self.num_iterations = train_cfg.get("num_iterations", 20000) self.train_batch_size = train_cfg.get("batch_size", 256) self.lr = train_cfg.get("learning_rate", 1e-3) self.optimizer = optim.Adam( self.model.parameters(), lr=self.lr, betas=(train_cfg.get("beta1", 0.9), train_cfg.get("beta2", 0.999)), weight_decay=train_cfg.get("weight_decay", 0.0), ) self.checkpoint_every = config.get("checkpoint_every", 5000) def build_trajectory_pool(self, num_batches: Optional[int] = None): if num_batches is None: num_batches = self.config.get("pool", {}).get("num_batches", 200) sink_batch_size = self.config.get("sinkhorn", {}).get("batch_size", 256) logger.info( f"Building trajectory pool: {num_batches} batches × " f"{sink_batch_size} samples × {self.gradient_flow.num_steps} steps" ) for batch_idx in range(num_batches): X0 = self.data_loader.sample_source(sink_batch_size, self.device) Y = self.data_loader.sample_target(sink_batch_size, self.device) _, trajectory = self.gradient_flow.run_flow(X0, Y, store_trajectory=True) self.pool.add_trajectory(trajectory) if (batch_idx + 1) % max(1, num_batches // 10) == 0: logger.info(f" Pool building: {batch_idx + 1}/{num_batches}, pool size: {len(self.pool)}") logger.info(f"Trajectory pool built. Total entries: {len(self.pool)}") # Free GPU memory used during Sinkhorn computation if self.device != "cpu": torch.cuda.empty_cache() # Pre-concatenate for O(1) sampling during training self.pool.finalize() logger.info("Trajectory pool finalized (pre-concatenated for fast sampling).") def train(self, start_step: int = 0) -> Dict[str, list]: self.model.train() history = {"loss": [], "step": []} logger.info(f"Starting NSGF velocity field matching: {self.num_iterations} iterations (from step {start_step})") for step in range(start_step, self.num_iterations): x_batch, v_batch, t_batch = self.pool.sample(self.train_batch_size, self.device) t_normalized = t_batch / max(self.gradient_flow.num_steps, 1.0) v_pred = self.model(x_batch, t_normalized) loss = ((v_pred - v_batch) ** 2).mean() self.optimizer.zero_grad() loss.backward() self.optimizer.step() if (step + 1) % 500 == 0 or step == start_step: loss_val = loss.item() history["loss"].append(loss_val) history["step"].append(step + 1) logger.info(f" Step {step + 1}/{self.num_iterations}, Loss: {loss_val:.6f}") if (step + 1) % self.checkpoint_every == 0: _save_checkpoint( os.path.join(self.checkpoint_dir, "nsgf_checkpoint.pt"), model_state=self.model.state_dict(), optimizer_state=self.optimizer.state_dict(), step=step + 1, history=history, ) logger.info("NSGF training complete.") return history class NSFTrainer: """Trainer for Neural Straight Flow (Phase 2 of NSGF++). Straight flow: X_t = (1-t)*P_0 + t*P_1, target velocity = P_1 - P_0 """ def __init__(self, model: nn.Module, nsgf_model: nn.Module, data_loader: DatasetLoader, config: dict, nsgf_num_steps: int = 5, device: str = "cpu", checkpoint_dir: str = "checkpoints"): self.model = model.to(device) self.nsgf_model = nsgf_model.to(device) self.nsgf_model.eval() self.data_loader = data_loader self.config = config self.device = device self.nsgf_num_steps = nsgf_num_steps self.checkpoint_dir = checkpoint_dir train_cfg = config.get("nsf_training", config.get("training", {})) self.num_iterations = train_cfg.get("num_iterations", 100000) self.train_batch_size = train_cfg.get("batch_size", 128) self.lr = train_cfg.get("learning_rate", 1e-4) self.optimizer = optim.Adam( self.model.parameters(), lr=self.lr, betas=(train_cfg.get("beta1", 0.9), train_cfg.get("beta2", 0.999)), weight_decay=train_cfg.get("weight_decay", 0.0), ) self.checkpoint_every = config.get("checkpoint_every", 5000) @torch.no_grad() def _generate_nsgf_samples(self, n: int) -> torch.Tensor: X = self.data_loader.sample_source(n, self.device) dt = 1.0 / self.nsgf_num_steps for step in range(self.nsgf_num_steps): t = torch.full((n,), step * dt, device=self.device) v = self.nsgf_model(X, t) X = X + dt * v return X def train(self, start_step: int = 0) -> Dict[str, list]: self.model.train() history = {"loss": [], "step": []} logger.info(f"Starting NSF training: {self.num_iterations} iterations (from step {start_step})") for step in range(start_step, self.num_iterations): P0 = self._generate_nsgf_samples(self.train_batch_size) P1 = self.data_loader.sample_target(self.train_batch_size, self.device) t = torch.rand(self.train_batch_size, device=self.device) if P0.dim() == 2: t_expand = t.unsqueeze(-1) else: t_expand = t.view(-1, 1, 1, 1) X_t = (1 - t_expand) * P0 + t_expand * P1 v_target = P1 - P0 v_pred = self.model(X_t, t) loss = ((v_pred - v_target) ** 2).mean() self.optimizer.zero_grad() loss.backward() self.optimizer.step() if (step + 1) % 500 == 0 or step == start_step: loss_val = loss.item() history["loss"].append(loss_val) history["step"].append(step + 1) logger.info(f" Step {step + 1}/{self.num_iterations}, Loss: {loss_val:.6f}") if (step + 1) % self.checkpoint_every == 0: _save_checkpoint( os.path.join(self.checkpoint_dir, "nsf_checkpoint.pt"), model_state=self.model.state_dict(), optimizer_state=self.optimizer.state_dict(), step=step + 1, history=history, ) logger.info("NSF training complete.") return history class PhaseTransitionTrainer: """Trainer for the phase-transition time predictor. Loss: L(ϕ) = E_{t~U(0,1)} ||t - t_ϕ(X_t)||² """ def __init__(self, predictor: PhaseTransitionPredictor, nsgf_model: nn.Module, data_loader: DatasetLoader, config: dict, nsgf_num_steps: int = 5, device: str = "cpu", checkpoint_dir: str = "checkpoints"): self.predictor = predictor.to(device) self.nsgf_model = nsgf_model.to(device) self.nsgf_model.eval() self.data_loader = data_loader self.config = config self.device = device self.nsgf_num_steps = nsgf_num_steps self.checkpoint_dir = checkpoint_dir tp_cfg = config.get("time_predictor", {}) self.num_iterations = tp_cfg.get("num_iterations", 40000) self.batch_size = tp_cfg.get("batch_size", 128) self.lr = tp_cfg.get("learning_rate", 1e-4) self.optimizer = optim.Adam(self.predictor.parameters(), lr=self.lr, betas=(0.9, 0.999)) self.checkpoint_every = config.get("checkpoint_every", 5000) @torch.no_grad() def _generate_nsgf_samples(self, n: int) -> torch.Tensor: X = self.data_loader.sample_source(n, self.device) dt = 1.0 / self.nsgf_num_steps for step in range(self.nsgf_num_steps): t = torch.full((n,), step * dt, device=self.device) v = self.nsgf_model(X, t) X = X + dt * v return X def train(self, start_step: int = 0) -> Dict[str, list]: self.predictor.train() history = {"loss": [], "step": []} logger.info(f"Starting phase predictor training: {self.num_iterations} iterations (from step {start_step})") for step in range(start_step, self.num_iterations): P0 = self._generate_nsgf_samples(self.batch_size) P1 = self.data_loader.sample_target(self.batch_size, self.device) t = torch.rand(self.batch_size, device=self.device) if P0.dim() == 4: t_expand = t.view(-1, 1, 1, 1) else: t_expand = t.unsqueeze(-1) X_t = (1 - t_expand) * P0 + t_expand * P1 t_pred = self.predictor(X_t) loss = ((t_pred - t) ** 2).mean() self.optimizer.zero_grad() loss.backward() self.optimizer.step() if (step + 1) % 1000 == 0 or step == start_step: loss_val = loss.item() history["loss"].append(loss_val) history["step"].append(step + 1) logger.info(f" Step {step + 1}/{self.num_iterations}, Loss: {loss_val:.6f}") if (step + 1) % self.checkpoint_every == 0: _save_checkpoint( os.path.join(self.checkpoint_dir, "predictor_checkpoint.pt"), model_state=self.predictor.state_dict(), optimizer_state=self.optimizer.state_dict(), step=step + 1, history=history, ) logger.info("Phase predictor training complete.") return history class NSGFPlusPlusTrainer: """End-to-end NSGF++ trainer (Algorithm 3 / Appendix D). Saves checkpoints after each phase so training can be resumed. """ def __init__(self, nsgf_model: nn.Module, nsf_model: nn.Module, phase_predictor: PhaseTransitionPredictor, data_loader: DatasetLoader, config: dict, device: str = "cpu", checkpoint_dir: str = "checkpoints"): self.nsgf_model = nsgf_model self.nsf_model = nsf_model self.phase_predictor = phase_predictor self.data_loader = data_loader self.config = config self.device = device self.checkpoint_dir = checkpoint_dir def train_all(self, resume_phase: int = 1) -> Dict[str, Any]: """Train all phases. resume_phase: 1=start from NSGF, 2=skip to NSF, 3=skip to predictor.""" results = {} os.makedirs(self.checkpoint_dir, exist_ok=True) if resume_phase <= 1: logger.info("=" * 60) logger.info("Phase 1: Training NSGF model") logger.info("=" * 60) nsgf_trainer = NSGFTrainer( model=self.nsgf_model, data_loader=self.data_loader, config=self.config, device=self.device, checkpoint_dir=self.checkpoint_dir, ) nsgf_trainer.build_trajectory_pool() results["nsgf"] = nsgf_trainer.train() _save_checkpoint( os.path.join(self.checkpoint_dir, "phase1_complete.pt"), nsgf_model_state=self.nsgf_model.state_dict(), phase=1, ) del nsgf_trainer.pool if self.device != "cpu": torch.cuda.empty_cache() else: logger.info(f"Skipping Phase 1 (resuming from phase {resume_phase})") if resume_phase <= 2: logger.info("=" * 60) logger.info("Phase 2: Training NSF (Neural Straight Flow) model") logger.info("=" * 60) nsgf_steps = self.config.get("sinkhorn", {}).get("num_steps", 5) nsf_trainer = NSFTrainer( model=self.nsf_model, nsgf_model=self.nsgf_model, data_loader=self.data_loader, config=self.config, nsgf_num_steps=nsgf_steps, device=self.device, checkpoint_dir=self.checkpoint_dir, ) results["nsf"] = nsf_trainer.train() _save_checkpoint( os.path.join(self.checkpoint_dir, "phase2_complete.pt"), nsgf_model_state=self.nsgf_model.state_dict(), nsf_model_state=self.nsf_model.state_dict(), phase=2, ) else: logger.info(f"Skipping Phase 2 (resuming from phase {resume_phase})") if resume_phase <= 3: logger.info("=" * 60) logger.info("Phase 3: Training phase-transition time predictor") logger.info("=" * 60) nsgf_steps = self.config.get("sinkhorn", {}).get("num_steps", 5) pt_trainer = PhaseTransitionTrainer( predictor=self.phase_predictor, nsgf_model=self.nsgf_model, data_loader=self.data_loader, config=self.config, nsgf_num_steps=nsgf_steps, device=self.device, checkpoint_dir=self.checkpoint_dir, ) results["phase_predictor"] = pt_trainer.train() _save_checkpoint( os.path.join(self.checkpoint_dir, "phase3_complete.pt"), nsgf_model_state=self.nsgf_model.state_dict(), nsf_model_state=self.nsf_model.state_dict(), predictor_state=self.phase_predictor.state_dict(), phase=3, ) logger.info("=" * 60) logger.info("NSGF++ training complete!") logger.info("=" * 60) return results