| import torch |
| import tqdm |
| import k_diffusion.sampling |
| from modules import sd_samplers_common, sd_samplers_kdiffusion, sd_samplers |
| from tqdm.auto import trange, tqdm |
| from k_diffusion import utils |
| from k_diffusion.sampling import to_d, default_noise_sampler, get_ancestral_step |
| import math |
| from importlib import import_module |
|
|
# Module object for k-diffusion's sampling code, resolved by its dotted name.
# NOTE(review): this name is not referenced elsewhere in the visible code.
sampling = import_module("k_diffusion.sampling")

# Display name and alias used when registering the sampler with sd_samplers.
NAME, ALIAS = 'Euler_A_Test', 'euler_a_test'
|
|
|
|
| |
| |
|
|
@torch.no_grad()
def sample_euler_ancestral_test(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps.

    Walks the noise schedule ``sigmas`` from first to last entry. At each step
    the model's prediction drives an Euler step from the current sigma down to
    ``sigma_down`` (as computed by ``get_ancestral_step``). Fresh noise scaled
    by ``sigma_up * s_noise`` is then added, but only when the next sigma is
    positive. Runs with gradient tracking disabled (``torch.no_grad``).

    Args:
        model: callable invoked as ``model(x, sigma * ones, **extra_args)``;
            its result is used as the denoised prediction passed to ``to_d``.
        x: starting tensor; its size along dim 0 sets the length of the
            per-sample sigma vector handed to ``model``.
        sigmas: indexable sequence of noise levels; ``len(sigmas) - 1`` steps
            are taken.
        extra_args: keyword arguments forwarded to ``model`` (none if None).
        callback: optional callable, invoked once per step with a dict holding
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: forwarded to ``trange`` (progress-bar control).
        eta: forwarded to ``get_ancestral_step``.
        s_noise: multiplier applied to the injected noise.
        noise_sampler: callable ``(sigma, sigma_next) -> noise``; defaults to
            ``default_noise_sampler(x)``.

    Returns:
        The tensor produced after the final step.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    # One sigma per batch element, so the model receives a vector of sigmas.
    batch_ones = x.new_ones([x.shape[0]])

    for step in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[step], sigmas[step + 1]

        denoised = model(x, sigma * batch_ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=eta)

        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})

        # Deterministic Euler step toward sigma_down. Rebinding (not in-place)
        # leaves the tensor handed to the callback untouched.
        derivative = to_d(x, sigma, denoised)
        x = x + derivative * (sigma_down - sigma)

        # Ancestral part: re-inject noise except when stepping to sigma <= 0.
        if sigma_next > 0:
            x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up

    return x
|
|
|
|
|
|
| |
# Register the sampler only if no sampler with this name is already present
# in sd_samplers.all_samplers (presumably guards against this script being
# executed more than once — TODO confirm). The generator short-circuits on the
# first match instead of building a throwaway list of names.
if NAME not in (sampler.name for sampler in sd_samplers.all_samplers):
    # (label, sampler function, aliases, options) entries to register.
    euler_max_samplers = [(NAME, sample_euler_ancestral_test, [ALIAS], {})]
    samplers_data_euler_max_samplers = [
        sd_samplers_common.SamplerData(
            label,
            # Bind funcname as a default argument so each constructor keeps its
            # own function rather than the loop variable's last value.
            lambda model, funcname=funcname: sd_samplers_kdiffusion.KDiffusionSampler(funcname, model),
            aliases,
            options,
        )
        for label, funcname, aliases, options in euler_max_samplers
        # Accept either a callable or the name of an attribute of
        # k_diffusion.sampling.
        if callable(funcname) or hasattr(k_diffusion.sampling, funcname)
    ]
    sd_samplers.all_samplers += samplers_data_euler_max_samplers
    # Rebuild the name -> SamplerData lookup so it includes the new entries.
    sd_samplers.all_samplers_map = {x.name: x for x in sd_samplers.all_samplers}
    sd_samplers.set_samplers()
|
|