import torch
import numpy as np
class Sampler:
    """Diffusion noise schedule with several reverse-process step functions.

    The schedule (``betas``, ``alphas``, ``alphas_cumprod``) is built once in
    ``__init__``. ``set_inference_timesteps`` selects the sub-sequence of
    training timesteps visited during sampling, and each ``*_step`` method
    moves the latents from one visited timestep to the next (less noisy) one.
    All random draws go through ``self.generator`` so runs are reproducible.
    """

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        """Build the noise schedule.

        Args:
            generator: Random generator used for every noise draw.
            num_training_steps: Number of timesteps in the schedule.
            beta_start: Noise variance of the first timestep.
            beta_end: Noise variance of the last timestep.
        """
        # Beta is the variance of the noise added at each step. The values are
        # spaced linearly between sqrt(beta_start) and sqrt(beta_end) and then
        # squared.
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        self.alphas = 1.0 - self.betas
        # alphas_cumprod[t] is the product of alphas[0..t] ("alpha bar").
        self.alphas_cumprod = torch.cumprod(self.alphas, 0)
        self.one = torch.tensor(1.0)
        self.zero = torch.tensor(0.0)
        self.generator = generator
        self.num_training_steps = num_training_steps
        # Until set_inference_timesteps() is called, sampling visits every
        # training timestep, so the stride between visited timesteps is 1.
        self.num_inference_steps = num_training_steps
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())
        # (x0 prediction, log-SNR) of the previous DPM-Solver++ step, if any.
        self._dpm_history = None

    def set_inference_timesteps(self, num_inference_steps=50):
        """Choose ``num_inference_steps`` evenly strided timesteps, descending.

        Raises:
            ValueError: If ``num_inference_steps`` is not in
                ``[1, num_training_steps]``.
        """
        if not 0 < num_inference_steps <= self.num_training_steps:
            raise ValueError(
                f"num_inference_steps must be in [1, {self.num_training_steps}], got {num_inference_steps}"
            )
        self.num_inference_steps = num_inference_steps
        # 1000 training steps, 50 inference steps -> 980, 960, ..., 20, 0
        step_ratio = self.num_training_steps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio)[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps)
        # A new schedule invalidates any multistep history.
        self._dpm_history = None

    def _get_previous_timestep(self, timestep: int) -> int:
        """Return the timestep visited after ``timestep`` (negative past the end)."""
        return timestep - (self.num_training_steps // self.num_inference_steps)

    def _alpha_prod(self, timestep: int) -> torch.Tensor:
        """Return alpha-bar at ``timestep``; 1.0 for the noise-free state (timestep < 0)."""
        return self.alphas_cumprod[timestep] if timestep >= 0 else self.one

    def _randn_like(self, reference: torch.Tensor) -> torch.Tensor:
        """Draw standard normal noise shaped like ``reference`` from ``self.generator``."""
        return torch.randn(reference.shape, generator=self.generator, device=reference.device, dtype=reference.dtype)

    def _get_variance(self, timestep: int) -> torch.Tensor:
        """Return the posterior variance for the step leaving ``timestep``."""
        prev_t = self._get_previous_timestep(timestep)
        alpha_prod_t = self._alpha_prod(timestep)
        alpha_prod_t_prev = self._alpha_prod(prev_t)
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
        # Formula (7) of the DDPM paper.
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        # Keep the variance strictly positive so its square root is well defined.
        return torch.clamp(variance, min=1e-20)

    def set_strength(self, strength=1):
        """Skip the first part of the schedule.

        ``strength`` is the fraction of inference steps that are kept; the
        first ``num_inference_steps - int(num_inference_steps * strength)``
        entries of ``self.timesteps`` are dropped.

        Raises:
            ValueError: If ``strength`` is outside ``[0, 1]``.
        """
        if not 0 <= strength <= 1:
            raise ValueError(f"strength must be in [0, 1], got {strength}")
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def ddpm_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """Take one ancestral DDPM step.

        Args:
            timestep: Current timestep value.
            latents: Latents at ``timestep``.
            model_output: Noise predicted by the model for ``latents``.

        Returns:
            Latents at the previous (less noisy) timestep.
        """
        t = int(timestep)
        prev_t = self._get_previous_timestep(t)
        alpha_prod_t = self._alpha_prod(t)
        alpha_prod_t_prev = self._alpha_prod(prev_t)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # Predicted original sample, formula (15) of the DDPM paper.
        pred_original_sample = (latents - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

        # Coefficients of the predicted original sample and of the current sample.
        pred_original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t

        # Mean of the previous sample.
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        # X = mu + sigma * Z with Z ~ N(0, 1); no noise is added at timestep 0.
        if t > 0:
            pred_prev_sample = pred_prev_sample + (self._get_variance(t) ** 0.5) * self._randn_like(model_output)
        return pred_prev_sample

    def ddim_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor, eta=0.0):
        """Take one DDIM step.

        Args:
            timestep: Current timestep value.
            latents: Latents at ``timestep``.
            model_output: Noise predicted by the model for ``latents``.
            eta: Stochasticity; 0 gives the deterministic update.

        Returns:
            Latents at the previous (less noisy) timestep.
        """
        t = int(timestep)
        prev_t = self._get_previous_timestep(t)
        alpha_t = self._alpha_prod(t)
        alpha_prev = self._alpha_prod(prev_t)

        # Predicted clean sample x_0.
        pred_original_sample = (latents - torch.sqrt(1 - alpha_t) * model_output) / torch.sqrt(alpha_t)

        sigma_t = eta * torch.sqrt((1 - alpha_prev) / (1 - alpha_t)) * torch.sqrt(1 - alpha_t / alpha_prev)
        # Clamp so rounding error or eta > 1 cannot put a negative value under the root.
        direction_scale = torch.sqrt(torch.clamp(1 - alpha_prev - sigma_t ** 2, min=0.0))

        prev_latent = torch.sqrt(alpha_prev) * pred_original_sample + direction_scale * model_output
        if eta > 0:
            # Drawn from self.generator so stochastic DDIM runs are reproducible.
            prev_latent = prev_latent + sigma_t * self._randn_like(latents)
        return prev_latent

    def euler_ancestral_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor, eta=1.0):
        """Take one Euler ancestral step.

        The latents are divided by sqrt(alpha_bar_t) so they read as
        ``x_0 + sigma * noise`` with ``sigma = sqrt((1 - alpha_bar) / alpha_bar)``;
        the Euler ancestral update is applied at that noise level and the
        result is scaled back by sqrt(alpha_bar_prev). With ``eta=0`` the
        update coincides with the deterministic DDIM step.

        Args:
            timestep: Current timestep value.
            latents: Latents at ``timestep``.
            model_output: Noise predicted by the model for ``latents``.
            eta: Fraction of the admissible fresh noise injected per step.

        Returns:
            Latents at the previous (less noisy) timestep.
        """
        t = int(timestep)
        prev_t = self._get_previous_timestep(t)
        alpha_t = self._alpha_prod(t)
        alpha_prev = self._alpha_prod(prev_t)

        sigma_from = torch.sqrt((1 - alpha_t) / alpha_t)
        sigma_to = torch.sqrt((1 - alpha_prev) / alpha_prev)

        # Split the target noise level into a deterministic part (sigma_down)
        # and freshly injected noise (sigma_up).
        sigma_up = torch.minimum(
            sigma_to,
            eta * torch.sqrt(sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2),
        )
        sigma_down = torch.sqrt(torch.clamp(sigma_to ** 2 - sigma_up ** 2, min=0.0))

        # Euler step: the derivative with respect to sigma is the predicted noise.
        sample = latents / torch.sqrt(alpha_t) + (sigma_down - sigma_from) * model_output

        # sigma_up is 0 on the final step (prev_t < 0), so skip the draw there.
        if eta > 0.0 and prev_t >= 0:
            sample = sample + sigma_up * self._randn_like(latents)
        return torch.sqrt(alpha_prev) * sample

    def dpm_solver_pp_2m_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """Take one DPM-Solver++(2M) step.

        NOTE: unlike the other step methods, ``timestep`` is the position in
        ``self.timesteps`` (0 for the first sampling step), not the timestep
        value. Calling with position 0 clears the multistep history. The first
        and the last step use the first-order update.

        Args:
            timestep: Position of the current step in ``self.timesteps``.
            latents: Latents at the current timestep.
            model_output: Noise predicted by the model for ``latents``.

        Returns:
            Latents at the next entry of ``self.timesteps``, or the predicted
            clean sample on the last step.
        """
        step_index = int(timestep)
        t = int(self.timesteps[step_index])

        alpha_bar_t = self.alphas_cumprod[t]
        alpha_t = alpha_bar_t ** 0.5
        sigma_t = (1 - alpha_bar_t) ** 0.5
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)  # half log-SNR

        # Data prediction x_0 at the current step.
        x0_t = (latents - sigma_t * model_output) / alpha_t

        history = None if step_index == 0 else getattr(self, "_dpm_history", None)
        self._dpm_history = (x0_t, lambda_t)

        if step_index + 1 >= len(self.timesteps):
            # Target is the noise-free state (alpha = 1, sigma = 0).
            return x0_t

        prev_t = int(self.timesteps[step_index + 1])
        alpha_bar_prev = self.alphas_cumprod[prev_t]
        alpha_prev = alpha_bar_prev ** 0.5
        sigma_prev = (1 - alpha_bar_prev) ** 0.5
        lambda_prev = torch.log(alpha_prev) - torch.log(sigma_prev)
        h = lambda_prev - lambda_t

        x0_hat = x0_t
        if history is not None:
            # Second-order correction from the previous data prediction.
            x0_last, lambda_last = history
            r = (lambda_t - lambda_last) / h
            x0_hat = x0_t + (x0_t - x0_last) / (2 * r)

        return (sigma_prev / sigma_t) * latents - alpha_prev * (torch.exp(-h) - 1.0) * x0_hat

    def add_noise(self, original_samples: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
        """Noise ``original_samples`` to the given timesteps (forward process).

        Args:
            original_samples: Clean samples; the first dimension is the batch.
            timesteps: One timestep per batch entry.

        Returns:
            ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``, with
            noise drawn from ``self.generator``.
        """
        device = original_samples.device
        alphas_cumprod = self.alphas_cumprod.to(device=device, dtype=original_samples.dtype)
        timesteps = torch.as_tensor(timesteps, device=device).long()

        # One coefficient per batch entry, with trailing singleton dimensions
        # so it broadcasts over the remaining sample dimensions.
        coeff_shape = (-1,) + (1,) * (original_samples.dim() - 1)
        alpha_prod = alphas_cumprod[timesteps].flatten()
        sqrt_alpha_prod = (alpha_prod ** 0.5).reshape(coeff_shape)
        sqrt_one_minus_alpha_prod = ((1 - alpha_prod) ** 0.5).reshape(coeff_shape)  # standard deviation

        # Equation (4) of the DDPM paper.
        noise = self._randn_like(original_samples)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples