""" Rectified Flow Training for LiquidDiffusion Training Objective (Rectified Flow): x_t = (1-t)*x0 + t*x1, t ~ U[0,1], x1 ~ N(0,I) v_target = x1 - x0 (constant velocity) L = E[||v_θ(x_t, t) - v_target||²] (simple MSE) Sampling (Euler ODE): Start from x_1 ~ N(0,I), integrate backward: x_{t-dt} = x_t - v_θ(x_t, t) * dt References: [1] Liu et al., "Flow Straight and Fast: Rectified Flow", ICLR 2023 [2] Lee et al., "Improving the Training of Rectified Flows", 2024 """ import math import copy import os import time import json import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset from torchvision import transforms from torchvision.utils import save_image, make_grid class RectifiedFlowTrainer: """Trainer for LiquidDiffusion using Rectified Flow objective. Features: - Simple MSE velocity loss (no noise schedule to tune) - Optional logit-normal time sampling (from SD3) - EMA model for stable sampling - Gradient clipping, mixed precision - Checkpoint save/load with resume support """ def __init__(self, model, optimizer=None, lr=1e-4, weight_decay=0.01, ema_decay=0.9999, grad_clip=1.0, time_sampling="logit_normal", logit_normal_mean=0.0, logit_normal_std=1.0, device="cuda", use_amp=True, amp_dtype="float16"): self.device = device self.model = model.to(device) self.ema_decay = ema_decay self.grad_clip = grad_clip self.time_sampling = time_sampling self.logit_normal_mean = logit_normal_mean self.logit_normal_std = logit_normal_std # AMP only on CUDA self.use_amp = use_amp and (device == "cuda") self.amp_dtype = getattr(torch, amp_dtype) if self.use_amp else torch.float32 if optimizer is None: self.optimizer = torch.optim.AdamW( model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999)) else: self.optimizer = optimizer # GradScaler only when AMP is active if self.use_amp and amp_dtype == "float16": self.scaler = torch.amp.GradScaler("cuda", enabled=True) else: self.scaler = torch.amp.GradScaler("cuda", enabled=False) self.ema_model = self._build_ema() self.step = 0 self.losses = [] def _build_ema(self): """Create EMA copy of model.""" ema = copy.deepcopy(self.model) ema.eval() for p in ema.parameters(): p.requires_grad_(False) return ema @torch.no_grad() def _update_ema(self): """Update EMA weights: ema = decay * ema + (1-decay) * model""" for ema_p, model_p in zip(self.ema_model.parameters(), self.model.parameters()): ema_p.data.mul_(self.ema_decay).add_(model_p.data, alpha=1 - self.ema_decay) def _sample_time(self, batch_size): """Sample timesteps. logit_normal puts more weight near t=0.5.""" eps = 1e-5 if self.time_sampling == "uniform": return torch.rand(batch_size, device=self.device) * (1 - 2*eps) + eps elif self.time_sampling == "logit_normal": u = torch.randn(batch_size, device=self.device) * self.logit_normal_std + self.logit_normal_mean return torch.sigmoid(u).clamp(eps, 1 - eps) raise ValueError(f"Unknown time_sampling: {self.time_sampling}") def train_step(self, x0): """Single training step. 
class RectifiedFlowTrainer:
    """Trainer for LiquidDiffusion using the Rectified Flow objective.

    Features:
    - Simple MSE velocity loss (no noise schedule to tune)
    - Optional logit-normal time sampling (from SD3)
    - EMA model for stable sampling
    - Gradient clipping, mixed precision
    - Checkpoint save/load with resume support
    """

    def __init__(self, model, optimizer=None, lr=1e-4, weight_decay=0.01,
                 ema_decay=0.9999, grad_clip=1.0,
                 time_sampling="logit_normal",
                 logit_normal_mean=0.0, logit_normal_std=1.0,
                 device="cuda", use_amp=True, amp_dtype="float16"):
        self.device = device
        self.model = model.to(device)
        self.ema_decay = ema_decay
        self.grad_clip = grad_clip
        self.time_sampling = time_sampling
        self.logit_normal_mean = logit_normal_mean
        self.logit_normal_std = logit_normal_std

        # AMP only on CUDA
        self.use_amp = use_amp and (device == "cuda")
        self.amp_dtype = getattr(torch, amp_dtype) if self.use_amp else torch.float32

        if optimizer is None:
            self.optimizer = torch.optim.AdamW(
                model.parameters(), lr=lr, weight_decay=weight_decay,
                betas=(0.9, 0.999))
        else:
            self.optimizer = optimizer

        # GradScaler only when AMP is active (bfloat16 needs no loss scaling)
        if self.use_amp and amp_dtype == "float16":
            self.scaler = torch.amp.GradScaler("cuda", enabled=True)
        else:
            self.scaler = torch.amp.GradScaler("cuda", enabled=False)

        self.ema_model = self._build_ema()
        self.step = 0
        self.losses = []

    def _build_ema(self):
        """Create EMA copy of model."""
        ema = copy.deepcopy(self.model)
        ema.eval()
        for p in ema.parameters():
            p.requires_grad_(False)
        return ema

    @torch.no_grad()
    def _update_ema(self):
        """Update EMA weights: ema = decay * ema + (1 - decay) * model."""
        for ema_p, model_p in zip(self.ema_model.parameters(), self.model.parameters()):
            ema_p.data.mul_(self.ema_decay).add_(model_p.data, alpha=1 - self.ema_decay)

    def _sample_time(self, batch_size):
        """Sample timesteps. logit_normal puts more weight near t=0.5."""
        eps = 1e-5
        if self.time_sampling == "uniform":
            return torch.rand(batch_size, device=self.device) * (1 - 2 * eps) + eps
        elif self.time_sampling == "logit_normal":
            u = torch.randn(batch_size, device=self.device) * self.logit_normal_std + self.logit_normal_mean
            return torch.sigmoid(u).clamp(eps, 1 - eps)
        raise ValueError(f"Unknown time_sampling: {self.time_sampling}")

    def train_step(self, x0):
        """Single training step.

        x0: [B, C, H, W] images in [-1, 1].
        """
        self.model.train()
        x0 = x0.to(self.device)
        x1 = torch.randn_like(x0)
        t = self._sample_time(x0.shape[0])
        t_expand = t[:, None, None, None]

        # Linear interpolation between data and noise; target is the constant velocity
        x_t = (1 - t_expand) * x0 + t_expand * x1
        v_target = x1 - x0

        # Forward with optional AMP
        if self.use_amp:
            with torch.amp.autocast("cuda", dtype=self.amp_dtype):
                v_pred = self.model(x_t, t)
                loss = F.mse_loss(v_pred, v_target)
        else:
            v_pred = self.model(x_t, t)
            loss = F.mse_loss(v_pred, v_target)

        # Backward
        self.optimizer.zero_grad(set_to_none=True)
        self.scaler.scale(loss).backward()

        if self.grad_clip > 0:
            self.scaler.unscale_(self.optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
        else:
            grad_norm = torch.tensor(0.0)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self._update_ema()
        self.step += 1

        loss_val = loss.item()
        self.losses.append(loss_val)
        return {
            'loss': loss_val,
            'grad_norm': grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,
            'step': self.step,
        }

    @torch.no_grad()
    def sample(self, batch_size=4, image_size=256, channels=3, num_steps=50, use_ema=True):
        """Generate images via Euler ODE integration from noise → data."""
        model = self.ema_model if use_ema else self.model
        model.eval()

        z = torch.randn(batch_size, channels, image_size, image_size, device=self.device)
        dt = 1.0 / num_steps

        # Integrate t from 1 (pure noise) down to 0 (data)
        for i in range(num_steps, 0, -1):
            t = torch.full((batch_size,), i / num_steps, device=self.device)
            if self.use_amp:
                with torch.amp.autocast("cuda", dtype=self.amp_dtype):
                    v = model(z, t)
                v = v.float()
            else:
                v = model(z, t)
            z = z - v * dt

        return z.clamp(-1, 1)

    def save_checkpoint(self, path, extra=None):
        """Save model, EMA, optimizer, scaler, and training state."""
        ckpt = {
            'model': self.model.state_dict(),
            'ema_model': self.ema_model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scaler': self.scaler.state_dict(),
            'step': self.step,
            'losses': self.losses[-1000:],
        }
        if extra:
            ckpt.update(extra)
        dir_path = os.path.dirname(path)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
        torch.save(ckpt, path)

    def load_checkpoint(self, path):
        """Load checkpoint and resume training."""
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(ckpt['model'])
        self.ema_model.load_state_dict(ckpt['ema_model'])
        self.optimizer.load_state_dict(ckpt['optimizer'])
        if 'scaler' in ckpt:
            self.scaler.load_state_dict(ckpt['scaler'])
        self.step = ckpt.get('step', 0)
        self.losses = ckpt.get('losses', [])
        print(f"Resumed from step {self.step}")
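
# ---------------------------------------------------------------------------
# Optional utility, not used by the trainer above: under the velocity
# parameterization (v = x1 - x0, so x_t = x0 + t * v), the clean image and the
# noise endpoint can be recovered algebraically from a prediction at time t.
# Handy for logging denoised previews mid-training. The function name is ours,
# not an existing API.
# ---------------------------------------------------------------------------

def velocity_to_endpoints(x_t, v, t):
    """Return (x0_hat, x1_hat) implied by a velocity prediction at time t.

    x0_hat = x_t - t * v          (estimated clean image)
    x1_hat = x_t + (1 - t) * v    (estimated noise endpoint)
    """
    t = t.view(-1, 1, 1, 1)
    return x_t - t * v, x_t + (1 - t) * v
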
class ImageDataset(Dataset):
    """Image dataset from a local folder or the HuggingFace Hub.

    Usage:
        # From HuggingFace
        ds = ImageDataset("huggan/CelebA-HQ", image_size=256)

        # From local folder
        ds = ImageDataset("/path/to/images", image_size=256)

        # With pre-loaded HF dataset
        from datasets import load_dataset
        hf_ds = load_dataset("huggan/CelebA-HQ", split="train")
        ds = ImageDataset(None, image_size=256, hf_dataset=hf_ds)
    """

    def __init__(self, source, image_size=256, split="train", image_column="image",
                 max_samples=None, hf_dataset=None):
        self.image_size = image_size
        self.image_column = image_column
        self.transform = transforms.Compose([
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.CenterCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # → [-1, 1]
        ])

        if hf_dataset is not None:
            self.data = hf_dataset
            self.mode = "hf"
        elif source and os.path.isdir(source):
            from glob import glob
            self.files = []
            for ext in ['*.png', '*.jpg', '*.jpeg', '*.webp', '*.bmp']:
                self.files.extend(glob(os.path.join(source, '**', ext), recursive=True))
            self.files.sort()
            if max_samples:
                self.files = self.files[:max_samples]
            self.mode = "folder"
        else:
            from datasets import load_dataset
            self.data = load_dataset(source, split=split)
            if max_samples:
                self.data = self.data.select(range(min(max_samples, len(self.data))))
            self.mode = "hf"

    def __len__(self):
        return len(self.files) if self.mode == "folder" else len(self.data)

    def __getitem__(self, idx):
        if self.mode == "folder":
            from PIL import Image
            img = Image.open(self.files[idx]).convert("RGB")
        else:
            img = self.data[idx][self.image_column]
            if not hasattr(img, 'convert'):
                from PIL import Image as PILImage
                img = PILImage.fromarray(img)
            img = img.convert("RGB")
        return self.transform(img)


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Cosine annealing with linear warmup; standard for diffusion training."""
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
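
# ---------------------------------------------------------------------------
# Example wiring (illustrative sketch, not an official entry point).
# The trainer only assumes a velocity model with signature model(x_t, t) -> v;
# ToyVelocityNet below is a throwaway stand-in for the actual LiquidDiffusion
# network, and the dataset name, batch size, and step counts are arbitrary
# example values.
# ---------------------------------------------------------------------------

class ToyVelocityNet(nn.Module):
    """Minimal (x_t, t) -> v network used only to make this example runnable."""

    def __init__(self, channels=3, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels + 1, hidden, 3, padding=1), nn.SiLU(),
            nn.Conv2d(hidden, hidden, 3, padding=1), nn.SiLU(),
            nn.Conv2d(hidden, channels, 3, padding=1),
        )

    def forward(self, x, t):
        # Broadcast t to a [B, 1, H, W] map and concatenate as an extra channel
        t_map = t.view(-1, 1, 1, 1).expand(-1, 1, x.shape[2], x.shape[3])
        return self.net(torch.cat([x, t_map], dim=1))


if __name__ == "__main__":
    total_steps = 100_000
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs("samples", exist_ok=True)

    dataset = ImageDataset("huggan/CelebA-HQ", image_size=256)
    loader = DataLoader(dataset, batch_size=16, shuffle=True,
                        num_workers=4, pin_memory=True, drop_last=True)

    trainer = RectifiedFlowTrainer(ToyVelocityNet(), lr=1e-4, device=device,
                                   use_amp=(device == "cuda"))
    scheduler = get_cosine_schedule_with_warmup(trainer.optimizer,
                                                num_warmup_steps=1_000,
                                                num_training_steps=total_steps)

    while trainer.step < total_steps:
        for batch in loader:
            metrics = trainer.train_step(batch)
            scheduler.step()

            if trainer.step % 100 == 0:
                print(f"step {metrics['step']:>7d}  loss {metrics['loss']:.4f}  "
                      f"grad_norm {metrics['grad_norm']:.3f}")

            if trainer.step % 5_000 == 0:
                samples = trainer.sample(batch_size=4, image_size=256, num_steps=50)
                save_image(make_grid((samples + 1) / 2, nrow=2),
                           f"samples/step_{trainer.step:07d}.png")
                trainer.save_checkpoint(f"checkpoints/step_{trainer.step:07d}.pt")

            if trainer.step >= total_steps:
                break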