"""
Sampling / inference for LiquidFlow.

Integrates the flow-matching ODE dx/dt = v_θ(x_t, t) from t=0 to t=1
with Euler's or Heun's method. The Euler update is:
    x_{t+dt} = x_t + v_θ(x_t, t) * dt

Integration starts from noise x_0 ~ N(0, I) and ends at x_1 (the clean image).
"""

import torch


@torch.no_grad()
def euler_sample(model, shape, num_steps=50, device='cpu', class_label=None, cfg_scale=0.0):
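    """
    Generate images with Euler (1st-order) ODE integration.

    Flow matching: integrate dx/dt = v_θ(x_t, t) from t=0 to t=1, where
    x_0 ~ N(0, I) is noise and x_1 is the generated image.
    Costs one model evaluation per step, or two when classifier-free
    guidance is active (cfg_scale > 0 and class_label is not None).
    """
    B = shape[0]
    x = torch.randn(shape, device=device)  # x_0 ~ N(0, I)
    dt = 1.0 / num_steps

    for i in range(num_steps):
        # Current time, one value per batch element.
        t = torch.full((B,), i * dt, device=device)
        v = model(x, t, class_label)
        if cfg_scale > 0 and class_label is not None:
            # Classifier-free guidance: extrapolate from the unconditional
            # velocity toward the conditional one (cfg_scale=1 leaves v unchanged).
            v_uncond = model(x, t, None)
            v = v_uncond + cfg_scale * (v - v_uncond)
        x = x + v * dt  # Euler step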
    
    return x


@torch.no_grad()
def heun_sample(model, shape, num_steps=25, device='cpu', class_label=None, cfg_scale=0.0):
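    """
    Generate images with Heun's method (2nd-order) ODE integration.

    More accurate than Euler for the same step count. Every step except the
    last costs 2 velocity evaluations; the last is a plain Euler step, so a
    run makes 2 * num_steps - 1 evaluations (each one is two model calls
    when classifier-free guidance is active).
    """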
    B = shape[0]
    x = torch.randn(shape, device=device)
    dt = 1.0 / num_steps
    
    def get_v(x_in, t_in):
        v = model(x_in, t_in, class_label)
        if cfg_scale > 0 and class_label is not None:
            v_uncond = model(x_in, t_in, None)
            v = v_uncond + cfg_scale * (v - v_uncond)
        return v
    
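    for i in range(num_steps):
        t = torch.full((B,), i * dt, device=device)
        # Clamp so floating-point error never pushes t_next past 1.0.
        t_next = torch.full((B,), min((i + 1) * dt, 1.0), device=device)
        k1 = get_v(x, t)  # slope at the start of the step
        x_hat = x + dt * k1  # Euler predictor
        if i < num_steps - 1:
            k2 = get_v(x_hat, t_next)  # slope at the predicted endpoint
            x = x + dt * 0.5 * (k1 + k2)  # average the two slopes
        else:
            # Final step: plain Euler update, so the model is never evaluated at t=1.
            x = x + dt * k1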
    
    return x


@torch.no_grad()
def generate_grid(model, num_images=16, num_steps=50, img_size=128,
                  device='cpu', class_label=None, cfg_scale=0.0, method='euler'):
    """Generate a batch of images for a grid. Returns a (B, C, H, W) tensor in [0, 1]."""
    shape = (num_images, 3, img_size, img_size)
    if method == 'euler':
        images = euler_sample(model, shape, num_steps, device, class_label, cfg_scale)
    elif method == 'heun':
        images = heun_sample(model, shape, num_steps, device, class_label, cfg_scale)
    else:
        raise ValueError(f"Unknown method: {method}")
    # Clip to [-1, 1], then rescale to [0, 1] for display.
    return images.clamp(-1, 1) * 0.5 + 0.5


def make_grid_image(images, nrow=4, padding=2):
    """Arrange (B, C, H, W) images in [0, 1] into a grid. Returns a PIL Image."""
    from torchvision.utils import make_grid
    from PIL import Image
    import numpy as np
    grid = make_grid(images, nrow=nrow, padding=padding, normalize=False)
    grid = grid.permute(1, 2, 0).cpu().numpy()
    grid = (grid * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(grid)
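

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original sampler: a scalar check of the
# claim in heun_sample that Heun's method is more accurate than Euler.
# It integrates dx/dt = x from t=0 to t=1 with x(0) = 1, whose exact solution
# is x(1) = e. The helper name `_scalar_ode_demo` is hypothetical, and the toy
# velocity field stands in for v_θ; it is not the trained model. Unlike
# heun_sample, this sketch applies the Heun corrector on every step,
# including the last.
# ---------------------------------------------------------------------------
def _scalar_ode_demo(num_steps=10):
    """Return (euler_error, heun_error) for dx/dt = x on [0, 1], x(0) = 1."""
    import math

    def v(x, t):
        return x  # toy velocity field (assumption); ignores t

    dt = 1.0 / num_steps
    x_euler = 1.0
    x_heun = 1.0
    for i in range(num_steps):
        t = i * dt
        # Euler: one velocity evaluation per step.
        x_euler = x_euler + dt * v(x_euler, t)
        # Heun: Euler predictor, then average the start and endpoint slopes.
        k1 = v(x_heun, t)
        k2 = v(x_heun + dt * k1, t + dt)
        x_heun = x_heun + dt * 0.5 * (k1 + k2)
    return abs(x_euler - math.e), abs(x_heun - math.e)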