# Source: liquid_diffusion/trainer.py (uploaded by krystv, revision 992d967, verified)
"""
Rectified Flow Training for LiquidDiffusion
Training Objective (Rectified Flow):
x_t = (1-t)*x0 + t*x1, t ~ U[0,1], x1 ~ N(0,I)
v_target = x1 - x0 (constant velocity)
L = E[||v_θ(x_t, t) - v_target||²] (simple MSE)
Sampling (Euler ODE):
Start from x_1 ~ N(0,I), integrate backward:
x_{t-dt} = x_t - v_θ(x_t, t) * dt
References:
[1] Liu et al., "Flow Straight and Fast: Rectified Flow", ICLR 2023
[2] Lee et al., "Improving the Training of Rectified Flows", 2024
"""
import math
import copy
import os
import time
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image, make_grid
class RectifiedFlowTrainer:
"""Trainer for LiquidDiffusion using Rectified Flow objective.
Features:
- Simple MSE velocity loss (no noise schedule to tune)
- Optional logit-normal time sampling (from SD3)
- EMA model for stable sampling
- Gradient clipping, mixed precision
- Checkpoint save/load with resume support
"""
def __init__(self, model, optimizer=None, lr=1e-4, weight_decay=0.01,
ema_decay=0.9999, grad_clip=1.0, time_sampling="logit_normal",
logit_normal_mean=0.0, logit_normal_std=1.0, device="cuda",
use_amp=True, amp_dtype="float16"):
self.device = device
self.model = model.to(device)
self.ema_decay = ema_decay
self.grad_clip = grad_clip
self.time_sampling = time_sampling
self.logit_normal_mean = logit_normal_mean
self.logit_normal_std = logit_normal_std
# AMP only on CUDA
self.use_amp = use_amp and (device == "cuda")
self.amp_dtype = getattr(torch, amp_dtype) if self.use_amp else torch.float32
if optimizer is None:
self.optimizer = torch.optim.AdamW(
model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))
else:
self.optimizer = optimizer
# GradScaler only when AMP is active
if self.use_amp and amp_dtype == "float16":
self.scaler = torch.amp.GradScaler("cuda", enabled=True)
else:
self.scaler = torch.amp.GradScaler("cuda", enabled=False)
self.ema_model = self._build_ema()
self.step = 0
self.losses = []
def _build_ema(self):
"""Create EMA copy of model."""
ema = copy.deepcopy(self.model)
ema.eval()
for p in ema.parameters():
p.requires_grad_(False)
return ema
@torch.no_grad()
def _update_ema(self):
"""Update EMA weights: ema = decay * ema + (1-decay) * model"""
for ema_p, model_p in zip(self.ema_model.parameters(), self.model.parameters()):
ema_p.data.mul_(self.ema_decay).add_(model_p.data, alpha=1 - self.ema_decay)
def _sample_time(self, batch_size):
"""Sample timesteps. logit_normal puts more weight near t=0.5."""
eps = 1e-5
if self.time_sampling == "uniform":
return torch.rand(batch_size, device=self.device) * (1 - 2*eps) + eps
elif self.time_sampling == "logit_normal":
u = torch.randn(batch_size, device=self.device) * self.logit_normal_std + self.logit_normal_mean
return torch.sigmoid(u).clamp(eps, 1 - eps)
raise ValueError(f"Unknown time_sampling: {self.time_sampling}")
def train_step(self, x0):
"""Single training step. x0: [B,C,H,W] images in [-1,1]."""
self.model.train()
x0 = x0.to(self.device)
x1 = torch.randn_like(x0)
t = self._sample_time(x0.shape[0])
t_expand = t[:, None, None, None]
x_t = (1 - t_expand) * x0 + t_expand * x1
v_target = x1 - x0
# Forward with optional AMP
if self.use_amp:
with torch.amp.autocast("cuda", dtype=self.amp_dtype):
v_pred = self.model(x_t, t)
loss = F.mse_loss(v_pred, v_target)
else:
v_pred = self.model(x_t, t)
loss = F.mse_loss(v_pred, v_target)
# Backward
self.optimizer.zero_grad(set_to_none=True)
self.scaler.scale(loss).backward()
if self.grad_clip > 0:
self.scaler.unscale_(self.optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
else:
grad_norm = torch.tensor(0.0)
self.scaler.step(self.optimizer)
self.scaler.update()
self._update_ema()
self.step += 1
loss_val = loss.item()
self.losses.append(loss_val)
return {
'loss': loss_val,
'grad_norm': grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,
'step': self.step,
}
@torch.no_grad()
def sample(self, batch_size=4, image_size=256, channels=3, num_steps=50, use_ema=True):
"""Generate images via Euler ODE integration from noise → data."""
model = self.ema_model if use_ema else self.model
model.eval()
z = torch.randn(batch_size, channels, image_size, image_size, device=self.device)
dt = 1.0 / num_steps
for i in range(num_steps, 0, -1):
t = torch.full((batch_size,), i / num_steps, device=self.device)
if self.use_amp:
with torch.amp.autocast("cuda", dtype=self.amp_dtype):
v = model(z, t)
v = v.float()
else:
v = model(z, t)
z = z - v * dt
return z.clamp(-1, 1)
def save_checkpoint(self, path, extra=None):
"""Save model, EMA, optimizer, scaler, and training state."""
ckpt = {
'model': self.model.state_dict(),
'ema_model': self.ema_model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'scaler': self.scaler.state_dict(),
'step': self.step,
'losses': self.losses[-1000:],
}
if extra:
ckpt.update(extra)
dir_path = os.path.dirname(path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
torch.save(ckpt, path)
def load_checkpoint(self, path):
"""Load checkpoint and resume training."""
ckpt = torch.load(path, map_location=self.device, weights_only=False)
self.model.load_state_dict(ckpt['model'])
self.ema_model.load_state_dict(ckpt['ema_model'])
self.optimizer.load_state_dict(ckpt['optimizer'])
if 'scaler' in ckpt:
self.scaler.load_state_dict(ckpt['scaler'])
self.step = ckpt.get('step', 0)
self.losses = ckpt.get('losses', [])
print(f"Resumed from step {self.step}")
class ImageDataset(Dataset):
    """Image dataset from local folder or HuggingFace Hub.

    Every item is resized, centre-cropped to ``image_size``, randomly flipped
    horizontally, converted to a tensor and normalised to [-1, 1].

    Usage:
        # From HuggingFace
        ds = ImageDataset("huggan/CelebA-HQ", image_size=256)
        # From local folder
        ds = ImageDataset("/path/to/images", image_size=256)
        # With pre-loaded HF dataset
        from datasets import load_dataset
        hf_ds = load_dataset("huggan/CelebA-HQ", split="train")
        ds = ImageDataset(None, image_size=256, hf_dataset=hf_ds)
    """

    def __init__(self, source, image_size=256, split="train",
                 image_column="image", max_samples=None, hf_dataset=None):
        """Resolve the data source and build the preprocessing pipeline.

        Args:
            source: Path of a local image folder (searched recursively) or a
                dataset name passed to ``datasets.load_dataset``. May be
                ``None`` when ``hf_dataset`` is given.
            image_size: Target side length of the square output images.
            split: Split passed to ``load_dataset`` when loading by name.
            image_column: Column holding the image in dataset rows.
            max_samples: Optional cap on the number of samples.
            hf_dataset: Optional pre-loaded dataset; takes precedence over
                ``source``.

        Raises:
            ValueError: If neither ``source`` nor ``hf_dataset`` is provided.
        """
        # Fail early with a clear message instead of handing None to
        # load_dataset further down.
        if hf_dataset is None and not source:
            raise ValueError(
                "ImageDataset needs either `source` (a folder path or dataset "
                "name) or `hf_dataset`")

        self.image_size = image_size
        self.image_column = image_column
        self.transform = transforms.Compose([
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.CenterCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # → [-1, 1]
        ])

        if hf_dataset is not None:
            self.data = hf_dataset
            # Honour max_samples for pre-loaded datasets too, as long as the
            # object offers a `select` method; otherwise it is left untouched.
            if max_samples and hasattr(hf_dataset, "select"):
                self.data = hf_dataset.select(range(min(max_samples, len(hf_dataset))))
            self.mode = "hf"
        elif os.path.isdir(source):
            from glob import glob
            self.files = []
            for ext in ['*.png', '*.jpg', '*.jpeg', '*.webp', '*.bmp']:
                self.files.extend(glob(os.path.join(source, '**', ext), recursive=True))
            # Sort so the sample order (and any max_samples cut) is stable.
            self.files.sort()
            if max_samples:
                self.files = self.files[:max_samples]
            self.mode = "folder"
        else:
            from datasets import load_dataset
            self.data = load_dataset(source, split=split)
            if max_samples:
                self.data = self.data.select(range(min(max_samples, len(self.data))))
            self.mode = "hf"

    def __len__(self):
        """Return the number of samples in the active source."""
        return len(self.files) if self.mode == "folder" else len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` as RGB and return the transformed tensor."""
        from PIL import Image
        if self.mode == "folder":
            # The context manager closes the file handle deterministically;
            # convert() returns an independent image that outlives it.
            with Image.open(self.files[idx]) as raw:
                img = raw.convert("RGB")
        else:
            img = self.data[idx][self.image_column]
            if not hasattr(img, 'convert'):
                # Not a PIL image; build one from the array-like value.
                img = Image.fromarray(img)
            img = img.convert("RGB")
        return self.transform(img)
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Cosine annealing with linear warmup — standard for diffusion training.

    The LR multiplier rises linearly from 0 to 1 over ``num_warmup_steps``,
    then follows half a cosine period down to 0 at ``num_training_steps``.
    Past ``num_training_steps`` the multiplier stays at 0.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of linear warmup steps.
        num_training_steps: Step at which the multiplier reaches 0.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping the schedule.
    """
    warmup_span = max(1, num_warmup_steps)
    decay_span = max(1, num_training_steps - num_warmup_steps)

    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(warmup_span)
        # Clamp progress: without it the cosine keeps oscillating and the
        # learning rate climbs back up once step exceeds num_training_steps.
        progress = min(1.0, float(step - num_warmup_steps) / float(decay_span))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)