import torch


def time_shift_sana(t: torch.Tensor, flow_shift: float = 1.0, sigma: float = 1.0):
    # SANA-style timestep shift: warps a uniform schedule on [0, 1]. With
    # sigma=1 this reduces to t / (t + flow_shift * (1 - t)), so flow_shift=1
    # is the identity and flow_shift > 1 concentrates steps near t=0.
    return (1 / flow_shift) / ((1 / flow_shift) + (1 / t - 1) ** sigma)
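

# A minimal sanity check for the shift (illustrative; this helper is an
# addition, not part of the original module). flow_shift=1 leaves t unchanged
# and, e.g., flow_shift=3 maps t=0.5 to 0.25.
def _demo_time_shift_sana():
    t = torch.linspace(0.1, 0.9, 5)
    assert torch.allclose(time_shift_sana(t, flow_shift=1.0), t)
    assert torch.allclose(
        time_shift_sana(torch.tensor([0.5]), flow_shift=3.0), torch.tensor([0.25])
    )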


def get_score_from_velocity(velocity, x, t):
    # Convert a flow-matching velocity into a score under the linear
    # interpolation path x_t = t * x1 + (1 - t) * x0, i.e. alpha_t = t,
    # sigma_t = 1 - t. For this path var simplifies to (1 - t), so
    # score = (t * velocity - x) / (1 - t).
    alpha_t, d_alpha_t = t, 1
    sigma_t, d_sigma_t = 1 - t, -1
    mean = x
    reverse_alpha_ratio = alpha_t / d_alpha_t
    var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
    score = (reverse_alpha_ratio * velocity - mean) / var
    return score
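

# Hedged check of the conversion (an added illustration, not original code):
# with x_t = t * x1 + (1 - t) * x0 and v = x1 - x0, the formula reduces to
# score(x_t) = -x0 / (1 - t), the conditional Gaussian score of N(t * x1, (1 - t)^2 I).
def _demo_score_from_velocity():
    x0, x1 = torch.randn(4, 8), torch.randn(4, 8)
    t = 0.3
    x_t = t * x1 + (1 - t) * x0
    v = x1 - x0
    assert torch.allclose(get_score_from_velocity(v, x_t, t), -x0 / (1 - t), atol=1e-5)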


def get_velocity_from_cfg(velocity, cfg, cfg_mult):
    # Classifier-free guidance. When cfg_mult == 2 the batch is assumed to be
    # [cond | uncond] stacked along dim 0; otherwise the velocity passes through.
    if cfg_mult == 2:
        cond_v, uncond_v = torch.chunk(velocity, 2, dim=0)
        velocity = uncond_v + cfg * (cond_v - uncond_v)
    return velocity
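

# Illustrative guidance check (added sketch; assumes the [cond | uncond]
# batch layout used above).
def _demo_cfg():
    cond_v, uncond_v = torch.randn(2, 8), torch.randn(2, 8)
    v = torch.cat([cond_v, uncond_v], dim=0)
    guided = get_velocity_from_cfg(v, cfg=4.0, cfg_mult=2)
    assert torch.allclose(guided, uncond_v + 4.0 * (cond_v - uncond_v))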


@torch.compile()
def euler_step(x, v, dt: float, cfg: float, cfg_mult: int):
    # Deterministic Euler update of the probability-flow ODE. The update runs
    # in float32 with autocast disabled so mixed-precision contexts upstream
    # cannot degrade the integration.
    with torch.amp.autocast("cuda", enabled=False):
        v = v.to(torch.float32)
        v = get_velocity_from_cfg(v, cfg, cfg_mult)
        x = x + v * dt
    return x
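

# Single-step sketch (added example). euler_step folds the CFG combination in
# before the update, so v may carry twice x's batch size.
def _demo_euler_step():
    x = torch.randn(2, 8)
    v = torch.randn(4, 8)  # [cond | uncond] velocities for the 2 samples
    x_next = euler_step(x, v, dt=0.02, cfg=2.0, cfg_mult=2)
    assert x_next.shape == x.shape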


@torch.compile()
def euler_maruyama_step(x, v, t, dt: float, cfg: float, cfg_mult: int):
    # One Euler-Maruyama update of the reverse-time SDE: the drift augments the
    # velocity with a score correction, and the diffusion term injects Gaussian
    # noise scaled by sqrt(2 * (1 - t) * dt).
    with torch.amp.autocast("cuda", enabled=False):
        v = v.to(torch.float32)
        v = get_velocity_from_cfg(v, cfg, cfg_mult)
        score = get_score_from_velocity(v, x, t)
        drift = v + (1 - t) * score
        noise_scale = (2.0 * (1.0 - t) * dt) ** 0.5
        x = x + drift * dt + noise_scale * torch.randn_like(x)
    return x
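

# Shape-level sketch of a single SDE step (added example; cfg disabled so the
# velocity passes through unchanged).
def _demo_euler_maruyama_step():
    x, v = torch.randn(2, 8), torch.randn(2, 8)
    x_next = euler_maruyama_step(x, v, t=0.3, dt=0.05, cfg=1.0, cfg_mult=1)
    assert x_next.shape == x.shape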


def euler_maruyama(
    input_dim,
    forward_fn,
    c: torch.Tensor,
    cfg: float = 1.0,
    num_sampling_steps: int = 20,
    last_step_size: float = 0.05,
    time_shift: float = 1.0,
):
    # Stochastic (SDE) sampler. forward_fn is expected to predict the clean
    # sample x1; the prediction is converted to a velocity below. The final
    # interval [1 - last_step_size, 1] is finished with a deterministic Euler step.
    cfg_mult = 1
    if cfg > 1.0:
        cfg_mult += 1
    x_shape = list(c.shape)
    x_shape[0] = x_shape[0] // cfg_mult
    x_shape[-1] = input_dim
    x = torch.randn(x_shape, device=c.device)
    t_all = torch.linspace(
        0, 1 - last_step_size, num_sampling_steps + 1, device=c.device, dtype=torch.float32
    )
    t_all = time_shift_sana(t_all, time_shift)
    dt = t_all[1:] - t_all[:-1]
    # Use a 0-d tensor rather than a Python float to avoid a torch.compile warning.
    t = torch.tensor(0.0, device=c.device, dtype=torch.float32)
    t_batch = torch.zeros(c.shape[0], device=c.device)
    for i in range(num_sampling_steps):
        t_batch[:] = t
        combined = torch.cat([x] * cfg_mult, dim=0)
        output = forward_fn(combined, t_batch, c)
        # Convert the x1 prediction to a velocity; the clamp keeps the
        # division stable as t approaches 1.
        v = (output - combined) / (1 - t_batch.view(-1, 1)).clamp_min(0.05)
        x = euler_maruyama_step(x, v, t, dt[i], cfg, cfg_mult)
        t += dt[i]
    # Final deterministic step over the remaining interval of width last_step_size.
    combined = torch.cat([x] * cfg_mult, dim=0)
    t_batch[:] = 1 - last_step_size
    output = forward_fn(combined, t_batch, c)
    v = (output - combined) / (1 - t_batch.view(-1, 1)).clamp_min(0.05)
    x = euler_step(x, v, last_step_size, cfg, cfg_mult)
    # Duplicate along the batch dim so the output matches c's batch size.
    return torch.cat([x] * cfg_mult, dim=0)
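

# Hypothetical end-to-end usage (added sketch). The stub forward_fn stands in
# for a model that predicts the clean sample x1 from (x_t, t, c); it only
# exercises shapes, not sample quality.
def _demo_euler_maruyama():
    def forward_fn(x_t, t, c):
        return x_t + c  # toy x1 prediction; a real model goes here
    c = torch.randn(4, 16)  # 4 conditions; cfg=1.0, so no cond/uncond stacking
    samples = euler_maruyama(input_dim=16, forward_fn=forward_fn, c=c, cfg=1.0)
    assert samples.shape == (4, 16)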


def euler(
    input_dim,
    forward_fn,
    c,
    cfg: float = 1.0,
    num_sampling_steps: int = 50,
):
    # Deterministic ODE sampler on a uniform timestep grid. Unlike
    # euler_maruyama, forward_fn here is expected to return the velocity directly.
    cfg_mult = 1
    if cfg > 1.0:
        cfg_mult = 2
    x_shape = list(c.shape)
    x_shape[0] = x_shape[0] // cfg_mult
    x_shape[-1] = input_dim
    x = torch.randn(x_shape, device=c.device)
    dt = 1.0 / num_sampling_steps
    t = 0.0
    t_batch = torch.zeros(c.shape[0], device=c.device)
    for _ in range(num_sampling_steps):
        t_batch[:] = t
        combined = torch.cat([x] * cfg_mult, dim=0)
        v = forward_fn(combined, t_batch, c)
        x = euler_step(x, v, dt, cfg, cfg_mult)
        t += dt
    return torch.cat([x] * cfg_mult, dim=0)
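

# Hypothetical guided usage of the deterministic sampler (added sketch). With
# cfg > 1 the condition tensor carries [cond | uncond] rows, so its batch
# dimension is twice the number of samples actually drawn.
def _demo_euler_cfg():
    def forward_fn(x_t, t, c):
        return c - x_t  # toy velocity prediction; a real model goes here
    c = torch.randn(8, 16)  # 4 cond rows stacked on 4 uncond rows
    samples = euler(input_dim=16, forward_fn=forward_fn, c=c, cfg=2.0)
    assert samples.shape == (8, 16)  # x is duplicated back to c's batch size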