import math

import torch
from torch import nn

from . import sampling, utils


class VDenoiser(nn.Module):
    """A v-diffusion-pytorch model wrapper for k-diffusion."""

    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model
        self.sigma_data = 1.

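    # Preconditioner scalings in the style of Karras et al. (2022); with
    # sigma_data fixed at 1 they reduce to the v-objective parameterization.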
    def get_scalings(self, sigma):
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_skip, c_out, c_in

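    # v-diffusion-pytorch models take continuous timesteps t in [0, 1); the two
    # mappings below (t = atan(sigma) * 2 / pi and sigma = tan(t * pi / 2)) are
    # exact inverses of each other.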
    def sigma_to_t(self, sigma):
        return sigma.atan() / math.pi * 2

    def t_to_sigma(self, t):
        return (t * math.pi / 2).tan()

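    # The regression target below simplifies to
    # (noise - input * sigma) / (sigma ** 2 + 1) ** 0.5, i.e. the v target of
    # the angular parameterization.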
    def loss(self, input, noise, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip


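# A minimal usage sketch (the model, shapes, and sampler choice here are
# hypothetical, not prescribed by this module):
#
#   model = VDenoiser(inner_model)
#   sigmas = sampling.get_sigmas_karras(50, 0.01, 80., device='cuda')
#   x = torch.randn([4, 3, 64, 64], device='cuda') * sigmas[0]
#   samples = sampling.sample_heun(model, x, sigmas)

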
class DiscreteSchedule(nn.Module):
    """A mapping between continuous noise levels (sigmas) and a list of discrete noise
    levels."""

    def __init__(self, sigmas, quantize):
        super().__init__()
        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())
        self.quantize = quantize

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

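    # Returns the full schedule (or n evenly interpolated steps of it) in
    # descending order, with a final zero appended as the samplers expect.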
    def get_sigmas(self, n=None):
        if n is None:
            return sampling.append_zero(self.sigmas.flip(0))
        t_max = len(self.sigmas) - 1
        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
        return sampling.append_zero(self.t_to_sigma(t))

    def sigma_to_discrete_timestep(self, sigma):
        log_sigma = sigma.log()
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

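    # Converts sigma to a (possibly fractional) timestep by piecewise-linear
    # interpolation in log-sigma space; with quantize=True it snaps to the
    # nearest discrete timestep instead.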
    def sigma_to_t(self, sigma, quantize=None):
        quantize = self.quantize if quantize is None else quantize
        if quantize:
            return self.sigma_to_discrete_timestep(sigma)
        log_sigma = sigma.log()
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
        w = (low - log_sigma) / (low - high)
        w = w.clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)

    def t_to_sigma(self, t):
        t = t.float()
        low_idx = t.floor().long()
        high_idx = t.ceil().long()
        # Tensor.frac has been unreliable on the MPS backend, so compute the
        # fractional part manually there.
        w = t - low_idx if t.device.type == 'mps' else t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp()

    def predict_eps_discrete_timestep(self, input, t, **kwargs):
        if t.dtype != torch.int64 and t.dtype != torch.int32:
            t = t.round()
        sigma = self.t_to_sigma(t)
        # Rescale from the variance-preserving convention to k-diffusion's
        # x = x0 + sigma * noise, using 1 / sqrt(alpha_bar) = sqrt(sigma ** 2 + 1).
        input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
        return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)


class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output eps (the predicted
    noise)."""

    def __init__(self, model, alphas_cumprod, quantize):
        # sigma = sqrt((1 - alpha_bar) / alpha_bar) converts the DDPM schedule
        # to k-diffusion noise levels.
        super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
        self.inner_model = model
        self.sigma_data = 1.

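    # For eps-prediction models c_skip is identically 1, so only c_out and c_in
    # are returned; the denoised output is input + eps * c_out.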
    def get_scalings(self, sigma):
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_out, c_in

    def get_eps(self, *args, **kwargs):
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        return (eps - noise).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return input + eps * c_out


class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for OpenAI diffusion models."""

    def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
        alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
        super().__init__(model, alphas_cumprod, quantize=quantize)
        self.has_learned_sigmas = has_learned_sigmas

    def get_eps(self, *args, **kwargs):
        model_output = self.inner_model(*args, **kwargs)
        if self.has_learned_sigmas:
            # Learned-variance models output eps and variance concatenated
            # along the channel dimension; keep only the eps half.
            return model_output.chunk(2, dim=1)[0]
        return model_output


class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for CompVis diffusion models."""

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, model.alphas_cumprod, quantize=quantize)

    def get_eps(self, *args, **kwargs):
        return self.inner_model.apply_model(*args, **kwargs)


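# A minimal usage sketch for an eps-prediction latent diffusion model (the
# model, conditioning, and sampler here are hypothetical):
#
#   model = CompVisDenoiser(ldm_model)
#   sigmas = model.get_sigmas(25)
#   x = torch.randn([1, 4, 64, 64], device='cuda') * sigmas[0]
#   samples = sampling.sample_lms(model, x, sigmas, extra_args={'cond': cond})

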
class DiscreteVDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output v."""

    def __init__(self, model, alphas_cumprod, quantize):
        super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_skip, c_out, c_in

    def get_v(self, *args, **kwargs):
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip


class CompVisVDenoiser(DiscreteVDDPMDenoiser):
    """A wrapper for CompVis diffusion models that output v."""

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, model.alphas_cumprod, quantize=quantize)

    def get_v(self, x, t, cond, **kwargs):
        return self.inner_model.apply_model(x, t, cond)
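

# For a v-prediction CompVis model the wrapper is used the same way (again a
# hypothetical sketch; the sampler choice is arbitrary):
#
#   model = CompVisVDenoiser(ldm_model)
#   sigmas = model.get_sigmas(25)
#   x = torch.randn([1, 4, 64, 64], device='cuda') * sigmas[0]
#   samples = sampling.sample_dpmpp_2m(model, x, sigmas, extra_args={'cond': cond})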