""" Sampling / inference for LiquidFlow. Uses ODE integration (Euler or Heun's method) to solve: x_{t+dt} = x_t + v_θ(x_t, t) * dt Starting from x_0 ~ N(0, I) and integrating to x_1 (clean image). """ import torch import torch.nn as nn from tqdm import tqdm @torch.no_grad() def euler_sample(model, shape, num_steps=50, device='cpu', class_label=None, cfg_scale=0.0): """ Generate images using Euler method ODE integration. Flow matching: integrate dx/dt = v_θ(x_t, t) from t=0 to t=1 x_0 ~ N(0, I) → x_1 = generated image """ B = shape[0] x = torch.randn(shape, device=device) dt = 1.0 / num_steps for i in range(num_steps): t = torch.full((B,), i * dt, device=device) v = model(x, t, class_label) if cfg_scale > 0 and class_label is not None: v_uncond = model(x, t, None) v = v_uncond + cfg_scale * (v - v_uncond) x = x + v * dt return x @torch.no_grad() def heun_sample(model, shape, num_steps=25, device='cpu', class_label=None, cfg_scale=0.0): """ Generate images using Heun's method (2nd order) ODE integration. More accurate than Euler. Each step costs 2 model evaluations. """ B = shape[0] x = torch.randn(shape, device=device) dt = 1.0 / num_steps def get_v(x_in, t_in): v = model(x_in, t_in, class_label) if cfg_scale > 0 and class_label is not None: v_uncond = model(x_in, t_in, None) v = v_uncond + cfg_scale * (v - v_uncond) return v for i in range(num_steps): t = torch.full((B,), i * dt, device=device) t_next = torch.full((B,), min((i + 1) * dt, 1.0), device=device) k1 = get_v(x, t) x_hat = x + dt * k1 if i < num_steps - 1: k2 = get_v(x_hat, t_next) x = x + dt * 0.5 * (k1 + k2) else: x = x + dt * k1 return x @torch.no_grad() def generate_grid(model, num_images=16, num_steps=50, img_size=128, device='cpu', class_label=None, cfg_scale=0.0, method='euler'): """Generate a grid of images. Returns (B, C, H, W) tensor in [0, 1].""" shape = (num_images, 3, img_size, img_size) if method == 'euler': images = euler_sample(model, shape, num_steps, device, class_label, cfg_scale) elif method == 'heun': images = heun_sample(model, shape, num_steps, device, class_label, cfg_scale) else: raise ValueError(f"Unknown method: {method}") return images.clamp(-1, 1) * 0.5 + 0.5 def make_grid_image(images, nrow=4, padding=2): """Arrange images into a grid. Returns a PIL Image.""" from torchvision.utils import make_grid from PIL import Image import numpy as np grid = make_grid(images, nrow=nrow, padding=padding, normalize=False) grid = grid.permute(1, 2, 0).cpu().numpy() grid = (grid * 255).clip(0, 255).astype(np.uint8) return Image.fromarray(grid)