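# NOTE: the lines below are page residue from the file host this module was
# copied from; they are kept as comments so the module remains valid Python.
# flying101's picture
# Upload 31 files
# 1b34e16 verified
# Raw
# History Blame Contribute Delete
# 9.32 kB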
import torch
import numpy as np
class Sampler:
    """Diffusion noise schedule plus several reverse-process update rules.

    The schedule is built once in ``__init__``. ``set_inference_timesteps`` and
    ``set_strength`` select which timesteps are visited at inference time.
    Every ``*_step`` method takes ``(timestep, latents, model_output)`` where
    ``timestep`` is a value taken from ``self.timesteps`` and ``model_output``
    is the predicted noise, and returns the latents at the previous (less
    noisy) timestep.
    """

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        """Build the beta / alpha schedule.

        Args:
            generator: Random generator used for every noise draw made by this
                sampler. It must live on the same device as the tensors that
                are passed to the step methods and to ``add_noise``.
            num_training_steps: Number of steps of the forward process.
            beta_start: Noise variance at the first step.
            beta_end: Noise variance at the last step.
        """
        # The values are evenly spaced in sqrt(beta) and then squared.
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        self.alphas = 1.0 - self.betas
        # alpha_bar_t: running product of the alphas up to and including t.
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)
        self.zero = torch.tensor(0.0)
        self.generator = generator
        self.num_training_steps = num_training_steps
        # Until set_inference_timesteps() is called every training step is
        # visited, which makes the stride used by _get_previous_timestep 1.
        self.num_inference_steps = num_training_steps
        self.start_step = 0
        self.timesteps = torch.arange(num_training_steps - 1, -1, -1, dtype=torch.int64)
        self._reset_multistep_state()

    def _reset_multistep_state(self) -> None:
        """Forget the history kept by ``dpm_solver_pp_2m_step``."""
        self._dpm_prev_x0 = None
        self._dpm_prev_h = None
        self._dpm_next_timestep = None

    def _full_schedule(self) -> torch.Tensor:
        """Return the descending timesteps for the current step count."""
        stride = self.num_training_steps // max(self.num_inference_steps, 1)
        return torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.int64) * stride

    def _alpha_bar(self, timestep) -> torch.Tensor:
        """Return alpha_bar at ``timestep``, or 1 for a negative timestep."""
        return self.alphas_cumprod[timestep] if timestep >= 0 else self.one

    def _randn_like(self, reference: torch.Tensor) -> torch.Tensor:
        """Draw standard normal noise shaped like ``reference`` from ``self.generator``."""
        return torch.randn(reference.shape, generator=self.generator, device=reference.device, dtype=reference.dtype)

    def set_inference_timesteps(self, num_inference_steps=50):
        """Select ``num_inference_steps`` evenly strided timesteps.

        With 1000 training steps and 50 inference steps the schedule is
        980, 960, ..., 20, 0.

        Raises:
            ValueError: If ``num_inference_steps`` is not between 1 and
                ``num_training_steps``.
        """
        if not 1 <= num_inference_steps <= self.num_training_steps:
            raise ValueError(
                f"num_inference_steps must be between 1 and {self.num_training_steps}, got {num_inference_steps}"
            )
        self.num_inference_steps = num_inference_steps
        self.timesteps = self._full_schedule()
        self.start_step = 0
        self._reset_multistep_state()

    def _get_previous_timestep(self, timestep: int) -> int:
        """Return the timestep one inference stride below ``timestep`` (may be negative)."""
        return timestep - (self.num_training_steps // self.num_inference_steps)

    def _get_variance(self, timestep: int) -> torch.Tensor:
        """Return the posterior variance used when stepping down from ``timestep``."""
        prev_t = self._get_previous_timestep(timestep)
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self._alpha_bar(prev_t)
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
        # Formula (7) of the DDPM paper.
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        # Keep the value strictly positive so its square root is well defined.
        return torch.clamp(variance, min=1e-20)

    def set_strength(self, strength=1):
        """Skip the noisiest part of the schedule.

        ``strength=1`` keeps every inference step; smaller values drop the
        first ``num_inference_steps - int(num_inference_steps * strength)``
        steps. The slice is always taken from the full schedule, so calling
        this method repeatedly does not compound.

        Raises:
            ValueError: If ``strength`` is outside ``[0, 1]``.
        """
        if not 0 <= strength <= 1:
            raise ValueError(f"strength must be in [0, 1], got {strength}")
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self._full_schedule()[start_step:]
        self.start_step = start_step
        self._reset_multistep_state()

    def ddpm_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """Take one ancestral DDPM step from ``timestep`` to the previous timestep.

        Args:
            timestep: Current timestep value.
            latents: Latents at ``timestep``.
            model_output: Predicted noise for ``latents``.

        Returns:
            Latents at the previous timestep. Noise is added unless
            ``timestep`` is 0.
        """
        t = timestep
        prev_t = self._get_previous_timestep(t)
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self._alpha_bar(prev_t)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # Predicted clean sample, formula (15) of the DDPM paper.
        pred_original_sample = (latents - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

        # Posterior mean as a blend of the predicted clean sample and x_t.
        pred_original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        if t > 0:
            # X = mu + sigma * Z with Z ~ N(0, 1).
            noise = self._randn_like(model_output)
            pred_prev_sample = pred_prev_sample + (self._get_variance(t) ** 0.5) * noise
        return pred_prev_sample

    def ddim_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor, eta=0.0):
        """Take one DDIM step from ``timestep`` to the previous timestep.

        Args:
            timestep: Current timestep value.
            latents: Latents at ``timestep``.
            model_output: Predicted noise for ``latents``.
            eta: Scale of the injected noise; 0 makes the step deterministic.

        Returns:
            Latents at the previous timestep.
        """
        t = timestep
        prev_t = self._get_previous_timestep(t)
        alpha_t = self.alphas_cumprod[t]
        alpha_prev = self._alpha_bar(prev_t)

        # Predicted clean sample x_0.
        pred_original_sample = (latents - torch.sqrt(1 - alpha_t) * model_output) / torch.sqrt(alpha_t)

        sigma_t = eta * torch.sqrt((1 - alpha_prev) / (1 - alpha_t)) * torch.sqrt(1 - alpha_t / alpha_prev)
        # Clamp so a large eta cannot push the square root argument below zero.
        direction_scale = torch.sqrt(torch.clamp(1 - alpha_prev - sigma_t ** 2, min=0.0))

        prev_latent = torch.sqrt(alpha_prev) * pred_original_sample + direction_scale * model_output
        if eta > 0:
            # Draw from self.generator so stochastic sampling is reproducible.
            prev_latent = prev_latent + sigma_t * self._randn_like(latents)
        return prev_latent

    def euler_ancestral_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor, eta=1.0):
        """Take one Euler ancestral step from ``timestep`` to the previous timestep.

        The step is carried out on ``latents / sqrt(alpha_bar)`` with noise
        level ``sqrt((1 - alpha_bar) / alpha_bar)`` and the result is scaled
        back by ``sqrt(alpha_bar)`` of the previous timestep. With ``eta=0``
        the result equals the deterministic DDIM step.

        Args:
            timestep: Current timestep value.
            latents: Latents at ``timestep``.
            model_output: Predicted noise for ``latents``.
            eta: Scale of the injected noise; 0 makes the step deterministic.

        Returns:
            Latents at the previous timestep.
        """
        t = timestep
        prev_t = self._get_previous_timestep(t)
        alpha_t = self.alphas_cumprod[t]
        alpha_prev = self._alpha_bar(prev_t)

        sigma_from = torch.sqrt((1 - alpha_t) / alpha_t)
        sigma_to = torch.sqrt((1 - alpha_prev) / alpha_prev)

        # Split the target noise level into a deterministic part (sigma_down)
        # and freshly injected noise (sigma_up).
        sigma_up_sq = sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2
        sigma_up = torch.minimum(sigma_to, eta * torch.sqrt(torch.clamp(sigma_up_sq, min=0.0)))
        sigma_down = torch.sqrt(torch.clamp(sigma_to ** 2 - sigma_up ** 2, min=0.0))

        # Euler step along the predicted noise direction.
        scaled = latents / torch.sqrt(alpha_t)
        scaled = scaled + model_output * (sigma_down - sigma_from)
        if eta > 0.0:
            scaled = scaled + sigma_up * self._randn_like(latents)
        return scaled * torch.sqrt(alpha_prev)

    def dpm_solver_pp_2m_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """
        One DPM-Solver++(2M) step with the same signature as ``ddim_step``.

        The first step of a run and the final step (whose target carries no
        noise) use the first-order update; every other step also uses the
        clean-sample prediction remembered from the preceding call. The
        history is cleared by ``set_inference_timesteps``, ``set_strength``
        and after the final step, and it is ignored when ``timestep`` is not
        the timestep the preceding call stepped to.

        Args:
            timestep: Current timestep value t (as found in ``self.timesteps``).
            latents: Latents at the current timestep, x_t.
            model_output: Predicted noise for ``latents``.
        Returns:
            Estimated latents at the previous timestep.
        """
        t = timestep
        prev_t = self._get_previous_timestep(t)

        alpha_bar_t = self.alphas_cumprod[t]
        alpha_t = torch.sqrt(alpha_bar_t)
        sigma_t = torch.sqrt(1 - alpha_bar_t)

        # Clean-sample prediction at the current timestep.
        x0_t = (latents - sigma_t * model_output) / alpha_t

        if prev_t < 0:
            # The target has no noise left, so the update reduces to x0 itself.
            self._reset_multistep_state()
            return x0_t

        alpha_bar_prev = self.alphas_cumprod[prev_t]
        alpha_prev = torch.sqrt(alpha_bar_prev)
        sigma_prev = torch.sqrt(1 - alpha_bar_prev)

        # Step size measured in lambda = log(alpha / sigma).
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_prev = torch.log(alpha_prev) - torch.log(sigma_prev)
        h = lambda_prev - lambda_t

        has_history = (
            self._dpm_prev_x0 is not None
            and self._dpm_next_timestep == int(t)
            and self._dpm_prev_x0.shape == x0_t.shape
        )
        if has_history:
            # Second-order multistep correction from the previous prediction.
            r = self._dpm_prev_h / h
            x0_before = self._dpm_prev_x0.to(device=x0_t.device, dtype=x0_t.dtype)
            x0_hat = x0_t + (x0_t - x0_before) / (2 * r)
        else:
            x0_hat = x0_t

        x_prev = (sigma_prev / sigma_t) * latents - alpha_prev * torch.expm1(-h) * x0_hat

        # Remember this step for the next call.
        self._dpm_prev_x0 = x0_t
        self._dpm_prev_h = h
        self._dpm_next_timestep = int(prev_t)
        return x_prev

    def add_noise(self, original_samples: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
        """Noise ``original_samples`` to the given timesteps.

        Implements equation (4) of the DDPM paper:
        ``sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise``.

        Args:
            original_samples: Clean samples; the first dimension is matched
                against ``timesteps``.
            timesteps: Timestep value(s) at which to noise the samples.

        Returns:
            The noised samples, with the shape of ``original_samples``.
        """
        alpha_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        # One factor per sample, with trailing singleton dimensions so it
        # broadcasts over the remaining dimensions of the samples.
        trailing = (1,) * (original_samples.dim() - 1)
        sqrt_alpha_prod = (alpha_cumprod[timesteps] ** 0.5).reshape(-1, *trailing)
        sqrt_one_minus_alpha_prod = ((1 - alpha_cumprod[timesteps]) ** 0.5).reshape(-1, *trailing)

        noise = self._randn_like(original_samples)
        noisy_samples = (sqrt_alpha_prod * original_samples) + (sqrt_one_minus_alpha_prod) * noise
        return noisy_samples