| """ |
| Sampling / inference for LiquidFlow. |
| |
| Uses ODE integration (Euler or Heun's method) to solve: |
| x_{t+dt} = x_t + v_θ(x_t, t) * dt |
| |
| Starting from x_0 ~ N(0, I) and integrating to x_1 (clean image). |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from tqdm import tqdm |
|
|
|
|
@torch.no_grad()
def euler_sample(model, shape, num_steps=50, device='cpu', class_label=None, cfg_scale=0.0):
    """
    Generate images using Euler method ODE integration.

    Flow matching: integrate dx/dt = v_θ(x_t, t) from t=0 to t=1,
    starting from x_0 ~ N(0, I); the final state x_1 is the generated image.

    Args:
        model: callable ``model(x, t, class_label)`` returning a velocity that
            is added to ``x``; it is called with ``class_label=None`` for the
            unconditional branch of classifier-free guidance.
        shape: shape of the sample batch; ``shape[0]`` is the batch size.
        num_steps: number of Euler steps; must be >= 1.
        device: device on which the initial noise and timesteps are created.
        class_label: conditioning forwarded to ``model`` (``None`` for
            unconditional sampling).
        cfg_scale: classifier-free guidance weight. Guidance is applied only
            when ``cfg_scale > 0`` and ``class_label`` is not None, as
            ``v_uncond + cfg_scale * (v_cond - v_uncond)``; each step then
            costs two model evaluations instead of one.

    Returns:
        Tensor of ``shape`` holding x_1 (not clamped).

    Raises:
        ValueError: if ``num_steps < 1`` (zero would divide by zero and a
            negative count would skip integration and return pure noise).
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    B = shape[0]
    x = torch.randn(shape, device=device)
    dt = 1.0 / num_steps
    # The guidance decision does not change between steps; decide it once.
    use_cfg = cfg_scale > 0 and class_label is not None

    for i in range(num_steps):
        # Left endpoint of the i-th interval, so t ranges over [0, 1 - dt].
        t = torch.full((B,), i * dt, device=device)
        v = model(x, t, class_label)
        if use_cfg:
            v_uncond = model(x, t, None)
            v = v_uncond + cfg_scale * (v - v_uncond)
        x = x + v * dt

    return x
|
|
|
|
@torch.no_grad()
def heun_sample(model, shape, num_steps=25, device='cpu', class_label=None, cfg_scale=0.0):
    """
    Generate images using Heun's method (2nd order) ODE integration.

    Integrates dx/dt = v_θ(x_t, t) from t=0 to t=1 starting at x_0 ~ N(0, I).
    Each step takes an Euler predictor ``x_hat = x + dt * k1`` and then
    advances with the average of the slopes at both ends of the interval.
    The final step is a plain Euler step (the model is never queried at
    t=1), so a run costs ``2 * num_steps - 1`` velocity evaluations, and
    twice that many model calls when classifier-free guidance is active.

    Args:
        model: callable ``model(x, t, class_label)`` returning a velocity that
            is added to ``x``; it is called with ``class_label=None`` for the
            unconditional branch of classifier-free guidance.
        shape: shape of the sample batch; ``shape[0]`` is the batch size.
        num_steps: number of integration steps; must be >= 1.
        device: device on which the initial noise and timesteps are created.
        class_label: conditioning forwarded to ``model`` (``None`` for
            unconditional sampling).
        cfg_scale: classifier-free guidance weight; applied only when
            ``cfg_scale > 0`` and ``class_label`` is not None, as
            ``v_uncond + cfg_scale * (v_cond - v_uncond)``.

    Returns:
        Tensor of ``shape`` holding x_1 (not clamped).

    Raises:
        ValueError: if ``num_steps < 1``.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    B = shape[0]
    x = torch.randn(shape, device=device)
    dt = 1.0 / num_steps
    use_cfg = cfg_scale > 0 and class_label is not None

    def get_v(x_in, t_in):
        """Velocity at (x_in, t_in), with classifier-free guidance if enabled."""
        v = model(x_in, t_in, class_label)
        if use_cfg:
            v_uncond = model(x_in, t_in, None)
            v = v_uncond + cfg_scale * (v - v_uncond)
        return v

    last_step = num_steps - 1
    for i in range(num_steps):
        t = torch.full((B,), i * dt, device=device)
        k1 = get_v(x, t)
        if i == last_step:
            # Final step is plain Euler, so the model is never evaluated at
            # t=1 exactly. NOTE(review): presumably because t=1 is at/after
            # the edge of the training range -- confirm intent.
            x = x + dt * k1
        else:
            # min() guards against (i + 1) * dt rounding slightly above 1.0.
            t_next = torch.full((B,), min((i + 1) * dt, 1.0), device=device)
            x_hat = x + dt * k1          # Euler predictor
            k2 = get_v(x_hat, t_next)    # slope at the predicted endpoint
            x = x + dt * 0.5 * (k1 + k2) # trapezoidal corrector

    return x
|
|
|
|
@torch.no_grad()
def generate_grid(model, num_images=16, num_steps=50, img_size=128,
                  device='cpu', class_label=None, cfg_scale=0.0, method='euler'):
    """
    Sample a batch of square 3-channel images and rescale them for display.

    Args:
        model: velocity model, forwarded unchanged to the chosen sampler.
        num_images: number of images (batch size).
        num_steps: integration steps passed to the sampler.
        img_size: height and width of each image.
        device: device passed to the sampler.
        class_label: conditioning passed to the sampler.
        cfg_scale: classifier-free guidance weight passed to the sampler.
        method: ``'euler'`` or ``'heun'``.

    Returns:
        A (num_images, 3, img_size, img_size) tensor with values in [0, 1].

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    # Resolve the integrator before doing any work.
    if method == 'euler':
        sampler = euler_sample
    elif method == 'heun':
        sampler = heun_sample
    else:
        raise ValueError(f"Unknown method: {method}")

    batch_shape = (num_images, 3, img_size, img_size)
    raw = sampler(model, batch_shape, num_steps, device, class_label, cfg_scale)

    # Clip to [-1, 1] (presumably the model's training range), then map to [0, 1].
    return raw.clamp(-1, 1) * 0.5 + 0.5
|
|
|
|
def make_grid_image(images, nrow=4, padding=2):
    """
    Arrange images into a grid and return it as a PIL Image.

    Args:
        images: image tensor(s) accepted by ``torchvision.utils.make_grid``,
            with values expected in [0, 1] (e.g. the output of
            ``generate_grid``); values outside that range are clipped.
        nrow: number of images per grid row.
        padding: pixels of padding between tiles.

    Returns:
        A ``PIL.Image.Image`` built from the HWC uint8 grid.
    """
    from torchvision.utils import make_grid
    from PIL import Image
    import numpy as np

    grid = make_grid(images, nrow=nrow, padding=padding, normalize=False)
    # detach(): .numpy() refuses tensors that require grad.
    # float(): numpy has no bfloat16, and half precision gains nothing here.
    grid = grid.detach().permute(1, 2, 0).cpu().float().numpy()
    # Round to nearest instead of truncating (as torchvision's save_image
    # does); a bare astype(uint8) floors and darkens every pixel slightly.
    grid = np.clip(grid * 255 + 0.5, 0, 255).astype(np.uint8)
    return Image.fromarray(grid)