LiquidFlow / liquidflow /sampling.py
krystv's picture
Add sampling.py — Euler and Heun ODE samplers
5614582 verified
"""
Sampling / inference for LiquidFlow.
Uses ODE integration (Euler or Heun's method) to solve dx/dt = v_θ(x_t, t),
whose Euler discretization is x_{t+dt} = x_t + v_θ(x_t, t) * dt.
Starting from x_0 ~ N(0, I) and integrating to x_1 (clean image).
"""
import torch
import torch.nn as nn
from tqdm import tqdm
@torch.no_grad()
def euler_sample(model, shape, num_steps=50, device='cpu', class_label=None, cfg_scale=0.0):
    """
    Generate images using Euler method ODE integration.

    Flow matching: integrate dx/dt = v_θ(x_t, t) from t=0 to t=1,
    x_0 ~ N(0, I) → x_1 = generated image, using the update
    x_{t+dt} = x_t + v_θ(x_t, t) * dt with dt = 1 / num_steps.

    Args:
        model: Callable invoked as ``model(x, t, class_label)``, where ``t`` is a
            1-D tensor of length ``shape[0]`` holding the current time for every
            batch element. Its output is added to ``x`` (scaled by ``dt``), so it
            is presumably a velocity with the same shape as ``x``.
        shape: Shape of the tensor to sample; ``shape[0]`` is the batch size.
        num_steps: Number of Euler steps; must be >= 1.
        device: Device on which the initial noise and time tensors are created.
        class_label: Conditioning forwarded to ``model``; ``None`` = unconditional.
        cfg_scale: Classifier-free guidance scale. Guidance is applied only when
            ``cfg_scale > 0`` and ``class_label`` is not None, as
            ``v_uncond + cfg_scale * (v_cond - v_uncond)`` (two model calls per
            step). Otherwise the plain conditional velocity is used.

    Returns:
        The tensor reached at t=1.

    Raises:
        ValueError: If ``num_steps`` < 1 (previously 0 raised ZeroDivisionError and
            negative values silently returned pure noise).
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    B = shape[0]
    x = torch.randn(shape, device=device)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        # Time is recomputed from the step index rather than accumulated,
        # so float round-off does not drift across steps.
        t = torch.full((B,), i * dt, device=device)
        v = model(x, t, class_label)
        if cfg_scale > 0 and class_label is not None:
            v_uncond = model(x, t, None)
            v = v_uncond + cfg_scale * (v - v_uncond)
        x = x + v * dt
    return x
@torch.no_grad()
def heun_sample(model, shape, num_steps=25, device='cpu', class_label=None, cfg_scale=0.0):
    """
    Generate images using Heun's method (2nd order) ODE integration.

    Integrates dx/dt = v_θ(x_t, t) from t=0 (x_0 ~ N(0, I)) to t=1. Every step
    except the last uses an Euler predictor followed by a trapezoidal corrector:

        x_hat     = x + dt * v(x, t)
        x_{t+dt}  = x + dt * 0.5 * (v(x, t) + v(x_hat, t + dt))

    The final step is a plain Euler step (no velocity evaluation at t=1).
    Cost is therefore 2 * num_steps - 1 velocity evaluations, each of which is
    two model calls when classifier-free guidance is active.

    Args:
        model: Callable invoked as ``model(x, t, class_label)``, where ``t`` is a
            1-D tensor of length ``shape[0]``; its output is used as the velocity.
        shape: Shape of the tensor to sample; ``shape[0]`` is the batch size.
        num_steps: Number of integration steps; must be >= 1.
        device: Device on which the initial noise and time tensors are created.
        class_label: Conditioning forwarded to ``model``; ``None`` = unconditional.
        cfg_scale: Classifier-free guidance scale, applied only when
            ``cfg_scale > 0`` and ``class_label`` is not None.

    Returns:
        The tensor reached at t=1.

    Raises:
        ValueError: If ``num_steps`` < 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    B = shape[0]
    x = torch.randn(shape, device=device)
    dt = 1.0 / num_steps

    def get_v(x_in, t_in):
        # Conditional velocity, optionally blended with the unconditional one (CFG).
        v = model(x_in, t_in, class_label)
        if cfg_scale > 0 and class_label is not None:
            v_uncond = model(x_in, t_in, None)
            v = v_uncond + cfg_scale * (v - v_uncond)
        return v

    last = num_steps - 1
    for i in range(num_steps):
        t = torch.full((B,), i * dt, device=device)
        k1 = get_v(x, t)
        if i == last:
            # Final step: plain Euler update; skip the corrector at t=1.
            x = x + dt * k1
            break
        # min() guards against float round-off pushing t_next past 1.
        t_next = torch.full((B,), min((i + 1) * dt, 1.0), device=device)
        x_hat = x + dt * k1                   # Euler predictor
        k2 = get_v(x_hat, t_next)
        x = x + dt * 0.5 * (k1 + k2)          # trapezoidal corrector
    return x
@torch.no_grad()
def generate_grid(model, num_images=16, num_steps=50, img_size=128,
                  device='cpu', class_label=None, cfg_scale=0.0, method='euler'):
    """Sample a batch of square 3-channel images and rescale them to [0, 1].

    Args:
        model: Velocity model forwarded to the selected sampler.
        num_images: Batch size B.
        num_steps: Number of integration steps passed to the sampler.
        img_size: Height and width of each image.
        device: Device used for sampling.
        class_label: Optional conditioning forwarded to the sampler.
        cfg_scale: Classifier-free guidance scale forwarded to the sampler.
        method: Integrator name, ``'euler'`` or ``'heun'``.

    Returns:
        (B, C, H, W) tensor: the sampler output clamped to [-1, 1] and mapped
        linearly to [0, 1].

    Raises:
        ValueError: If ``method`` names an unknown integrator.
    """
    if method == 'heun':
        sampler = heun_sample
    elif method == 'euler':
        sampler = euler_sample
    else:
        raise ValueError(f"Unknown method: {method}")

    batch_shape = (num_images, 3, img_size, img_size)
    raw = sampler(model, batch_shape, num_steps, device, class_label, cfg_scale)
    # The clamp assumes outputs nominally live in [-1, 1]; map that range to [0, 1].
    return raw.clamp(-1, 1).mul(0.5).add(0.5)
def make_grid_image(images, nrow=4, padding=2):
    """Arrange images into a grid. Returns a PIL Image.

    Args:
        images: Batch accepted by ``torchvision.utils.make_grid`` (a 4-D
            B x C x H x W tensor or a list of C x H x W tensors), with values
            expected in [0, 1] such as the output of ``generate_grid``. Values
            outside [0, 1] are clipped.
        nrow: Number of images per grid row.
        padding: Pixels of padding between images.

    Returns:
        ``PIL.Image.Image`` built from the uint8 (H, W, C) grid.
    """
    from torchvision.utils import make_grid
    from PIL import Image
    import numpy as np
    grid = make_grid(images, nrow=nrow, padding=padding, normalize=False)
    # detach() so .numpy() also works if the input still tracks gradients.
    grid = grid.permute(1, 2, 0).detach().cpu().numpy()
    # Round to nearest (not truncate) so [0, 1] -> [0, 255] is unbiased;
    # this mirrors torchvision.utils.save_image's mul(255).add(0.5).
    grid = (grid * 255 + 0.5).clip(0, 255).astype(np.uint8)
    return Image.fromarray(grid)