| """trainer.py — Training procedures for NSGF and NSGF++. |
| |
| Implements: |
| 1. Trajectory pool construction (Phase 1: Sinkhorn gradient flow) |
| 2. NSGF velocity field matching training |
| 3. NSF (Neural Straight Flow) training for NSGF++ |
| 4. Phase-transition time predictor training |
| 5. End-to-end NSGF++ training pipeline |
| 6. Checkpointing and resume support |
| |
| Reference: arXiv:2401.14069, Section 4.2–4.4, Appendix D, E |
| """ |

import os
import logging

import torch
import torch.nn as nn
import torch.optim as optim
from typing import Optional, Dict, Any

from dataset_loader import DatasetLoader
from sinkhorn_flow import (
    SinkhornPotentialComputer, SinkhornGradientFlow, TrajectoryPool,
)
from model import VelocityMLP, VelocityUNet, PhaseTransitionPredictor

logger = logging.getLogger(__name__)


def _save_checkpoint(path: str, **kwargs) -> None:
    """Save a checkpoint dict to disk, creating the parent directory if needed."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save(kwargs, path)
    logger.info(f"Checkpoint saved: {path}")


class NSGFTrainer:
    """Trainer for the Neural Sinkhorn Gradient Flow model.

    Loss (Eq. 14): L(θ) = E_{(x, v̂, t) ~ pool} ||v_θ(x, t) - v̂||²

    where v̂ is the Sinkhorn descent direction stored in the trajectory pool
    and t is the flow step, normalized to [0, 1] before entering the model.
    """
    def __init__(self, model: nn.Module, data_loader: DatasetLoader,
                 config: dict, device: str = "cpu",
                 checkpoint_dir: str = "checkpoints"):
        self.model = model.to(device)
        self.data_loader = data_loader
        self.config = config
        self.device = device
        self.checkpoint_dir = checkpoint_dir

        sink_cfg = config.get("sinkhorn", {})
        self.potential_computer = SinkhornPotentialComputer(
            blur=sink_cfg.get("blur", 0.5), scaling=sink_cfg.get("scaling", 0.80),
        )
        self.gradient_flow = SinkhornGradientFlow(
            potential_computer=self.potential_computer,
            eta=sink_cfg.get("eta", 1.0), num_steps=sink_cfg.get("num_steps", 5),
        )
        self.pool = TrajectoryPool(
            max_size=config.get("pool", {}).get("max_size", 5_000_000),
        )

        train_cfg = config.get("training", config.get("nsgf_training", {}))
        self.num_iterations = train_cfg.get("num_iterations", 20000)
        self.train_batch_size = train_cfg.get("batch_size", 256)
        self.lr = train_cfg.get("learning_rate", 1e-3)
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.lr,
            betas=(train_cfg.get("beta1", 0.9), train_cfg.get("beta2", 0.999)),
            weight_decay=train_cfg.get("weight_decay", 0.0),
        )
        self.checkpoint_every = config.get("checkpoint_every", 5000)

    def build_trajectory_pool(self, num_batches: Optional[int] = None):
        if num_batches is None:
            num_batches = self.config.get("pool", {}).get("num_batches", 200)
        sink_batch_size = self.config.get("sinkhorn", {}).get("batch_size", 256)
        logger.info(
            f"Building trajectory pool: {num_batches} batches × "
            f"{sink_batch_size} samples × {self.gradient_flow.num_steps} steps"
        )
        # One Sinkhorn gradient-flow run per batch; every (x, v̂, t) triple
        # along the trajectory is added to the pool.
        for batch_idx in range(num_batches):
            X0 = self.data_loader.sample_source(sink_batch_size, self.device)
            Y = self.data_loader.sample_target(sink_batch_size, self.device)
            _, trajectory = self.gradient_flow.run_flow(X0, Y, store_trajectory=True)
            self.pool.add_trajectory(trajectory)
            if (batch_idx + 1) % max(1, num_batches // 10) == 0:
                logger.info(f"  Pool building: {batch_idx + 1}/{num_batches}, pool size: {len(self.pool)}")
        logger.info(f"Trajectory pool built. Total entries: {len(self.pool)}")

        if str(self.device).startswith("cuda"):
            torch.cuda.empty_cache()

        self.pool.finalize()
        logger.info("Trajectory pool finalized (pre-concatenated for fast sampling).")

    def train(self, start_step: int = 0) -> Dict[str, list]:
        self.model.train()
        history = {"loss": [], "step": []}
        logger.info(f"Starting NSGF velocity field matching: {self.num_iterations} iterations (from step {start_step})")
        for step in range(start_step, self.num_iterations):
            x_batch, v_batch, t_batch = self.pool.sample(self.train_batch_size, self.device)
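            # t from the pool counts Sinkhorn steps; rescale to [0, 1] to match
            # the time parameterization the velocity model expects.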
            t_normalized = t_batch / max(self.gradient_flow.num_steps, 1.0)
            v_pred = self.model(x_batch, t_normalized)
            loss = ((v_pred - v_batch) ** 2).mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if (step + 1) % 500 == 0 or step == start_step:
                loss_val = loss.item()
                history["loss"].append(loss_val)
                history["step"].append(step + 1)
                logger.info(f"  Step {step + 1}/{self.num_iterations}, Loss: {loss_val:.6f}")
            if (step + 1) % self.checkpoint_every == 0:
                _save_checkpoint(
                    os.path.join(self.checkpoint_dir, "nsgf_checkpoint.pt"),
                    model_state=self.model.state_dict(),
                    optimizer_state=self.optimizer.state_dict(),
                    step=step + 1,
                    history=history,
                )
        logger.info("NSGF training complete.")
        return history


class NSFTrainer:
    """Trainer for Neural Straight Flow (Phase 2 of NSGF++).

    Straight flow: X_t = (1-t)*P_0 + t*P_1, target velocity = P_1 - P_0,
    where P_0 is drawn from the trained NSGF model's output distribution
    and P_1 from the target data.
    """
    def __init__(self, model: nn.Module, nsgf_model: nn.Module,
                 data_loader: DatasetLoader, config: dict,
                 nsgf_num_steps: int = 5, device: str = "cpu",
                 checkpoint_dir: str = "checkpoints"):
        self.model = model.to(device)
        self.nsgf_model = nsgf_model.to(device)
        self.nsgf_model.eval()
        self.data_loader = data_loader
        self.config = config
        self.device = device
        self.nsgf_num_steps = nsgf_num_steps
        self.checkpoint_dir = checkpoint_dir

        train_cfg = config.get("nsf_training", config.get("training", {}))
        self.num_iterations = train_cfg.get("num_iterations", 100000)
        self.train_batch_size = train_cfg.get("batch_size", 128)
        self.lr = train_cfg.get("learning_rate", 1e-4)
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.lr,
            betas=(train_cfg.get("beta1", 0.9), train_cfg.get("beta2", 0.999)),
            weight_decay=train_cfg.get("weight_decay", 0.0),
        )
        self.checkpoint_every = config.get("checkpoint_every", 5000)

    @torch.no_grad()
    def _generate_nsgf_samples(self, n: int) -> torch.Tensor:
        """Euler-integrate the frozen NSGF velocity field from t = 0 to t = 1."""
        X = self.data_loader.sample_source(n, self.device)
        dt = 1.0 / self.nsgf_num_steps
        for step in range(self.nsgf_num_steps):
            t = torch.full((n,), step * dt, device=self.device)
            v = self.nsgf_model(X, t)
            X = X + dt * v
        return X

    def train(self, start_step: int = 0) -> Dict[str, list]:
        self.model.train()
        history = {"loss": [], "step": []}
        logger.info(f"Starting NSF training: {self.num_iterations} iterations (from step {start_step})")
        for step in range(start_step, self.num_iterations):
            P0 = self._generate_nsgf_samples(self.train_batch_size)
            P1 = self.data_loader.sample_target(self.train_batch_size, self.device)
            t = torch.rand(self.train_batch_size, device=self.device)
            # Broadcast t over all non-batch dimensions (vectors and images alike).
            t_expand = t.view(-1, *([1] * (P0.dim() - 1)))
            X_t = (1 - t_expand) * P0 + t_expand * P1
            v_target = P1 - P0
            v_pred = self.model(X_t, t)
            loss = ((v_pred - v_target) ** 2).mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if (step + 1) % 500 == 0 or step == start_step:
                loss_val = loss.item()
                history["loss"].append(loss_val)
                history["step"].append(step + 1)
                logger.info(f"  Step {step + 1}/{self.num_iterations}, Loss: {loss_val:.6f}")
            if (step + 1) % self.checkpoint_every == 0:
                _save_checkpoint(
                    os.path.join(self.checkpoint_dir, "nsf_checkpoint.pt"),
                    model_state=self.model.state_dict(),
                    optimizer_state=self.optimizer.state_dict(),
                    step=step + 1,
                    history=history,
                )
        logger.info("NSF training complete.")
        return history


class PhaseTransitionTrainer:
    """Trainer for the phase-transition time predictor.

    Loss: L(ϕ) = E_{t~U(0,1)} ||t - t_ϕ(X_t)||², with X_t = (1-t)*P_0 + t*P_1
    interpolated exactly as in the NSF trainer.
    """
    def __init__(self, predictor: PhaseTransitionPredictor, nsgf_model: nn.Module,
                 data_loader: DatasetLoader, config: dict,
                 nsgf_num_steps: int = 5, device: str = "cpu",
                 checkpoint_dir: str = "checkpoints"):
        self.predictor = predictor.to(device)
        self.nsgf_model = nsgf_model.to(device)
        self.nsgf_model.eval()
        self.data_loader = data_loader
        self.config = config
        self.device = device
        self.nsgf_num_steps = nsgf_num_steps
        self.checkpoint_dir = checkpoint_dir

        tp_cfg = config.get("time_predictor", {})
        self.num_iterations = tp_cfg.get("num_iterations", 40000)
        self.batch_size = tp_cfg.get("batch_size", 128)
        self.lr = tp_cfg.get("learning_rate", 1e-4)
        self.optimizer = optim.Adam(self.predictor.parameters(), lr=self.lr, betas=(0.9, 0.999))
        self.checkpoint_every = config.get("checkpoint_every", 5000)

    @torch.no_grad()
    def _generate_nsgf_samples(self, n: int) -> torch.Tensor:
        """Euler-integrate the frozen NSGF velocity field (as in NSFTrainer)."""
        X = self.data_loader.sample_source(n, self.device)
        dt = 1.0 / self.nsgf_num_steps
        for step in range(self.nsgf_num_steps):
            t = torch.full((n,), step * dt, device=self.device)
            v = self.nsgf_model(X, t)
            X = X + dt * v
        return X

    def train(self, start_step: int = 0) -> Dict[str, list]:
        self.predictor.train()
        history = {"loss": [], "step": []}
        logger.info(f"Starting phase predictor training: {self.num_iterations} iterations (from step {start_step})")
        for step in range(start_step, self.num_iterations):
            P0 = self._generate_nsgf_samples(self.batch_size)
            P1 = self.data_loader.sample_target(self.batch_size, self.device)
            t = torch.rand(self.batch_size, device=self.device)
            # Broadcast t over all non-batch dimensions (vectors and images alike).
            t_expand = t.view(-1, *([1] * (P0.dim() - 1)))
            X_t = (1 - t_expand) * P0 + t_expand * P1
            # Flatten to (n,) so each prediction pairs with its own t; an
            # (n, 1)-shaped output would otherwise broadcast the loss to (n, n).
            t_pred = self.predictor(X_t).view(-1)
            loss = ((t_pred - t) ** 2).mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if (step + 1) % 1000 == 0 or step == start_step:
                loss_val = loss.item()
                history["loss"].append(loss_val)
                history["step"].append(step + 1)
                logger.info(f"  Step {step + 1}/{self.num_iterations}, Loss: {loss_val:.6f}")
            if (step + 1) % self.checkpoint_every == 0:
                _save_checkpoint(
                    os.path.join(self.checkpoint_dir, "predictor_checkpoint.pt"),
                    model_state=self.predictor.state_dict(),
                    optimizer_state=self.optimizer.state_dict(),
                    step=step + 1,
                    history=history,
                )
        logger.info("Phase predictor training complete.")
        return history


class NSGFPlusPlusTrainer:
    """End-to-end NSGF++ trainer (Algorithm 3 / Appendix D).

    Saves checkpoints after each phase so training can be resumed.
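
    Typical usage (illustrative; model and loader construction live in
    model.py and dataset_loader.py):

        trainer = NSGFPlusPlusTrainer(nsgf, nsf, predictor, loader, config)
        trainer.train_all()                 # full pipeline, Phases 1-3
        trainer.train_all(resume_phase=2)   # reuse an already-trained NSGF model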
| """ |
| def __init__(self, nsgf_model: nn.Module, nsf_model: nn.Module, |
| phase_predictor: PhaseTransitionPredictor, |
| data_loader: DatasetLoader, config: dict, device: str = "cpu", |
| checkpoint_dir: str = "checkpoints"): |
| self.nsgf_model = nsgf_model |
| self.nsf_model = nsf_model |
| self.phase_predictor = phase_predictor |
| self.data_loader = data_loader |
| self.config = config |
| self.device = device |
| self.checkpoint_dir = checkpoint_dir |
|
|
    def train_all(self, resume_phase: int = 1) -> Dict[str, Any]:
        """Train all phases. resume_phase: 1=start from NSGF, 2=skip to NSF, 3=skip to predictor."""
        results = {}
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        if resume_phase <= 1:
            logger.info("=" * 60)
            logger.info("Phase 1: Training NSGF model")
            logger.info("=" * 60)
            nsgf_trainer = NSGFTrainer(
                model=self.nsgf_model, data_loader=self.data_loader,
                config=self.config, device=self.device,
                checkpoint_dir=self.checkpoint_dir,
            )
            nsgf_trainer.build_trajectory_pool()
            results["nsgf"] = nsgf_trainer.train()
            _save_checkpoint(
                os.path.join(self.checkpoint_dir, "phase1_complete.pt"),
                nsgf_model_state=self.nsgf_model.state_dict(),
                phase=1,
            )
            # Free the (large) trajectory pool before the later phases.
            del nsgf_trainer.pool
            if str(self.device).startswith("cuda"):
                torch.cuda.empty_cache()
        else:
            logger.info(f"Skipping Phase 1 (resuming from phase {resume_phase})")

        if resume_phase <= 2:
            logger.info("=" * 60)
            logger.info("Phase 2: Training NSF (Neural Straight Flow) model")
            logger.info("=" * 60)
            nsgf_steps = self.config.get("sinkhorn", {}).get("num_steps", 5)
            nsf_trainer = NSFTrainer(
                model=self.nsf_model, nsgf_model=self.nsgf_model,
                data_loader=self.data_loader, config=self.config,
                nsgf_num_steps=nsgf_steps, device=self.device,
                checkpoint_dir=self.checkpoint_dir,
            )
            results["nsf"] = nsf_trainer.train()
            _save_checkpoint(
                os.path.join(self.checkpoint_dir, "phase2_complete.pt"),
                nsgf_model_state=self.nsgf_model.state_dict(),
                nsf_model_state=self.nsf_model.state_dict(),
                phase=2,
            )
        else:
            logger.info(f"Skipping Phase 2 (resuming from phase {resume_phase})")

        if resume_phase <= 3:
            logger.info("=" * 60)
            logger.info("Phase 3: Training phase-transition time predictor")
            logger.info("=" * 60)
            nsgf_steps = self.config.get("sinkhorn", {}).get("num_steps", 5)
            pt_trainer = PhaseTransitionTrainer(
                predictor=self.phase_predictor, nsgf_model=self.nsgf_model,
                data_loader=self.data_loader, config=self.config,
                nsgf_num_steps=nsgf_steps, device=self.device,
                checkpoint_dir=self.checkpoint_dir,
            )
            results["phase_predictor"] = pt_trainer.train()
            _save_checkpoint(
                os.path.join(self.checkpoint_dir, "phase3_complete.pt"),
                nsgf_model_state=self.nsgf_model.state_dict(),
                nsf_model_state=self.nsf_model.state_dict(),
                predictor_state=self.phase_predictor.state_dict(),
                phase=3,
            )

        logger.info("=" * 60)
        logger.info("NSGF++ training complete!")
        logger.info("=" * 60)
        return results
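

# Illustrative inference sketch (not invoked by the trainers above). It shows
# how the three trained components are assumed to compose at sampling time,
# mirroring the trainers' setup: NSGF output plays the role of P_0 (t ≈ 0 on
# the straight path), and the NSF field transports it toward t = 1. The
# function name and the Euler step counts are illustrative choices, not part
# of the reference pipeline.
@torch.no_grad()
def sample_nsgf_plus_plus(nsgf_model: nn.Module, nsf_model: nn.Module,
                          phase_predictor: PhaseTransitionPredictor,
                          data_loader: DatasetLoader, n: int,
                          nsgf_num_steps: int = 5, nsf_num_steps: int = 10,
                          device: str = "cpu") -> torch.Tensor:
    # Phase 1: Euler-integrate the NSGF velocity field from the source samples,
    # exactly as in NSFTrainer._generate_nsgf_samples.
    X = data_loader.sample_source(n, device)
    dt = 1.0 / nsgf_num_steps
    for step in range(nsgf_num_steps):
        t = torch.full((n,), step * dt, device=device)
        X = X + dt * nsgf_model(X, t)
    # Phase 2: let the predictor place each sample on the straight path, then
    # integrate the NSF field from the predicted time up to t = 1.
    t = phase_predictor(X).view(-1).clamp(0.0, 1.0)
    dt = (1.0 - t) / nsf_num_steps
    dt_x = dt.view(-1, *([1] * (X.dim() - 1)))  # broadcast over sample dims
    for _ in range(nsf_num_steps):
        X = X + dt_x * nsf_model(X, t)
        t = t + dt
    return X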