import numpy as np
import torch


class Sampler:
    """Noise schedule plus several reverse-diffusion update rules.

    The schedule built in ``__init__`` spaces the betas linearly in sqrt-space
    between ``beta_start`` and ``beta_end`` and then squares them. Every
    ``*_step`` method takes the actual diffusion timestep ``t`` (an index into
    ``alphas_cumprod``, e.g. an element of ``self.timesteps``), the current
    latents ``x_t`` and the model's noise prediction ``eps``. It returns the
    latents for the previous inference timestep.
    """

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        """Build the noise schedule.

        Args:
            generator: RNG used for every random draw (step noise and
                ``add_noise``). A seeded generator makes sampling reproducible.
            num_training_steps: Number of discrete training timesteps.
            beta_start: First beta of the schedule.
            beta_end: Last beta of the schedule.
        """
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        self.alphas = 1.0 - self.betas
        # alphas_cumprod[t] is \bar{alpha}_t, the cumulative product of alphas.
        self.alphas_cumprod = torch.cumprod(self.alphas, 0)
        # \bar{alpha} used for the (virtual) timestep before t = 0.
        self.one = torch.tensor(1.0)

        self.generator = generator
        self.num_training_steps = num_training_steps
        # The default timesteps below visit every training step (step ratio 1).
        # Setting this here lets _get_previous_timestep work before
        # set_inference_timesteps is ever called.
        self.num_inference_steps = num_training_steps
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())
        self._reset_multistep_state()

    def _reset_multistep_state(self) -> None:
        """Forget the history kept by ``dpm_solver_pp_2m_step``."""
        self._dpm_prev_t = None
        self._dpm_prev_x0 = None

    def _randn(self, like: torch.Tensor) -> torch.Tensor:
        """Standard-normal noise with the shape/device/dtype of ``like``, drawn from ``self.generator``."""
        return torch.randn(like.shape, generator=self.generator, device=like.device, dtype=like.dtype)

    def set_inference_timesteps(self, num_inference_steps=50):
        """Use ``num_inference_steps`` evenly spaced timesteps, in descending order.

        The spacing is ``num_training_steps // num_inference_steps`` and the
        last timestep is 0. Any DPM-Solver++ history is discarded.

        Raises:
            ValueError: if ``num_inference_steps`` is not in
                ``[1, num_training_steps]``.
        """
        if not 1 <= num_inference_steps <= self.num_training_steps:
            raise ValueError(
                f"num_inference_steps must be between 1 and {self.num_training_steps}, got {num_inference_steps}"
            )
        self.num_inference_steps = num_inference_steps

        step_ratio = self.num_training_steps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio)[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps)
        self._reset_multistep_state()

    def _get_previous_timestep(self, timestep: int) -> int:
        """Return the timestep one inference step after ``timestep`` (may be negative)."""
        prev_t = timestep - (self.num_training_steps // self.num_inference_steps)
        return prev_t

    def _get_variance(self, timestep: int) -> torch.Tensor:
        """Variance of the DDPM posterior q(x_prev | x_t, x_0), clamped to at least 1e-20."""
        prev_t = self._get_previous_timestep(timestep)

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        # Clamp so that taking the square root / log later never hits exactly 0.
        variance = torch.clamp(variance, min=1e-20)
        return variance

    def set_strength(self, strength=1):
        """Skip the first ``num_inference_steps - int(num_inference_steps * strength)`` timesteps.

        ``strength=1`` keeps every timestep. Smaller values start denoising from
        a later (less noisy) timestep. Any DPM-Solver++ history is discarded.
        """
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step
        self._reset_multistep_state()

    def ddpm_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """One ancestral DDPM step.

        Returns the mean of the posterior q(x_prev | x_t, x0_pred), plus
        Gaussian noise scaled by the posterior standard deviation when t > 0.
        """
        t = timestep
        prev_t = self._get_previous_timestep(t)

        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # x0 estimate implied by the noise prediction.
        pred_original_sample = (latents - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

        # Posterior-mean coefficients for x0 and x_t.
        pred_original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        # No noise is added on the final step (t == 0).
        if t > 0:
            noise = self._randn(model_output)
            pred_prev_sample = pred_prev_sample + (self._get_variance(t) ** 0.5) * noise
        return pred_prev_sample

    def ddim_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor, eta=0.0):
        """One DDIM step. ``eta=0`` is deterministic; ``eta>0`` injects noise from ``self.generator``."""
        t = timestep
        prev_t = self._get_previous_timestep(t)

        alpha_t = self.alphas_cumprod[t]
        alpha_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one

        pred_original_sample = (latents - torch.sqrt(1 - alpha_t) * model_output) / torch.sqrt(alpha_t)

        sigma_t = eta * torch.sqrt((1 - alpha_prev) / (1 - alpha_t)) * torch.sqrt(1 - alpha_t / alpha_prev)
        # For eta > 1 this can drop below 0; clamp so sqrt never yields NaN.
        direction_coeff = torch.sqrt(torch.clamp(1 - alpha_prev - sigma_t ** 2, min=0.0))

        prev_latent = torch.sqrt(alpha_prev) * pred_original_sample + direction_coeff * model_output
        if eta > 0:
            prev_latent = prev_latent + sigma_t * self._randn(latents)
        return prev_latent

    def euler_ancestral_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor, eta=1.0):
        """One Euler-ancestral step, in the k-diffusion formulation.

        The latents are mapped to x / sqrt(abar_t), whose noise level is
        sigma = sqrt((1 - abar) / abar). In that space the Euler derivative
        (x - x0_pred) / sigma equals the predicted noise. The step therefore:

        1. Moves deterministically down to ``sigma_down``.
        2. Adds fresh noise with std ``sigma_up``, drawn from ``self.generator``,
           so the total noise level is ``sigma_to``.
        3. Maps the result back by multiplying with sqrt(abar_prev).

        ``eta`` scales the ancestral noise. ``eta=0`` gives a deterministic step
        equal to DDIM with eta=0. On the final step (previous timestep < 0) the
        result is the x0 prediction.
        """
        t = timestep
        prev_t = self._get_previous_timestep(t)

        alpha_t = self.alphas_cumprod[t]
        alpha_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one

        sigma_from = ((1 - alpha_t) / alpha_t) ** 0.5
        sigma_to = ((1 - alpha_prev) / alpha_prev) ** 0.5
        sigma_up = torch.minimum(
            sigma_to,
            eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5,
        )
        sigma_down = torch.clamp(sigma_to ** 2 - sigma_up ** 2, min=0.0) ** 0.5

        x = latents / alpha_t ** 0.5
        x = x + (sigma_down - sigma_from) * model_output
        if eta > 0.0 and prev_t >= 0:
            x = x + sigma_up * self._randn(latents)
        return x * alpha_prev ** 0.5

    def _alpha_sigma(self, t: int):
        """Return (sqrt(abar_t), sqrt(1 - abar_t)) for a non-negative timestep ``t``."""
        abar = self.alphas_cumprod[t]
        return abar.sqrt(), (1 - abar).sqrt()

    def dpm_solver_pp_2m_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """
        One DPM-Solver++(2M) step with DDIM-style signature.

        The step works in log-SNR time lambda = log(alpha / sigma), where
        alpha = sqrt(abar) and sigma = sqrt(1 - abar). It applies a
        second-order correction using the x0 prediction from the previous call,
        but only if that call stepped onto this ``timestep``. Otherwise (first
        step, or after ``set_inference_timesteps``/``set_strength``) it uses the
        first-order update, which equals deterministic DDIM. The final step
        (previous timestep < 0) returns the x0 prediction.

        Args:
            timestep: Current timestep t (index into alphas_cumprod), as in ddim_step.
            latents: Latents at current timestep x_t.
            model_output: Model prediction ε_θ(x_t, t).

        Returns:
            Latents at the previous inference timestep.
        """
        t = int(timestep)
        prev_t = self._get_previous_timestep(t)

        alpha_t, sigma_t = self._alpha_sigma(t)
        x0_pred = (latents - sigma_t * model_output) / alpha_t

        has_history = (
            self._dpm_prev_x0 is not None
            and self._get_previous_timestep(self._dpm_prev_t) == t
        )
        last_t, last_x0 = self._dpm_prev_t, self._dpm_prev_x0
        self._dpm_prev_t, self._dpm_prev_x0 = t, x0_pred

        if prev_t < 0:
            # sigma_prev -> 0 (lambda -> inf): the update reduces to the x0 prediction.
            return x0_pred

        alpha_prev, sigma_prev = self._alpha_sigma(prev_t)
        lambda_t = torch.log(alpha_t / sigma_t)
        lambda_prev = torch.log(alpha_prev / sigma_prev)
        h = lambda_prev - lambda_t
        phi = torch.expm1(-h)  # exp(-h) - 1, accurate for small h

        # First-order (DDIM-equivalent) update.
        x_prev = (sigma_prev / sigma_t) * latents - alpha_prev * phi * x0_pred
        if has_history:
            alpha_last, sigma_last = self._alpha_sigma(last_t)
            h_last = lambda_t - torch.log(alpha_last / sigma_last)
            # D1 = (m0 - m1) / r0 with r0 = h_last / h.
            d1 = (x0_pred - last_x0) * (h / h_last)
            x_prev = x_prev - 0.5 * alpha_prev * phi * d1
        return x_prev

    def add_noise(self, original_samples: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
        """Forward-diffuse samples: sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise.

        Args:
            original_samples: Clean samples. The first dimension is matched
                against the flattened ``timesteps``.
            timesteps: One timestep per sample (or a single timestep). A plain
                int is also accepted.

        Returns:
            Noisy samples with the same shape as ``original_samples``. The noise
            is drawn from ``self.generator``.
        """
        alpha_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = torch.as_tensor(timesteps, device=original_samples.device)

        # Shape (B, 1, 1, ...) so the coefficients broadcast over all non-batch dims.
        coeff_shape = (-1,) + (1,) * (original_samples.ndim - 1)
        abar = alpha_cumprod[timesteps].flatten()
        sqrt_alpha_prod = (abar ** 0.5).reshape(coeff_shape)
        sqrt_one_minus_alpha_prod = ((1 - abar) ** 0.5).reshape(coeff_shape)

        noise = self._randn(original_samples)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples