import dataclasses import torch import k_diffusion import inspect from types import SimpleNamespace from refnet.util import default from .scheduler import schedulers, schedulers_map from .denoiser import CFGDenoiser defaults = SimpleNamespace(**{ "eta_ddim": 0.0, "eta_ancestral": 1.0, "ddim_discretize": "uniform", "s_churn": 0.0, "s_tmin": 0.0, "s_noise": 1.0, "k_sched_type": "Automatic", "sigma_min": 0.0, "sigma_max": 0.0, "rho": 0.0, "eta_noise_seed_delta": 0, "always_discard_next_to_last_sigma": False, }) @dataclasses.dataclass class Sampler: label: str funcname: str aliases: any options: dict samplers_k_diffusion = [ Sampler('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {'scheduler': 'karras'}), Sampler('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), Sampler('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'scheduler': 'exponential', "brownian_noise": True}), Sampler('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}), Sampler('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}), Sampler('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}), Sampler('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}), Sampler('Euler', 'sample_euler', ['k_euler'], {}), Sampler('LMS', 'sample_lms', ['k_lms'], {}), Sampler('Heun', 'sample_heun', ['k_heun'], {"second_order": True}), Sampler('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "second_order": True}), Sampler('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), Sampler('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}), Sampler('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}) ] sampler_extra_params = { 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_dpm_fast': ['s_noise'], 'sample_dpm_2_ancestral': ['s_noise'], 'sample_dpmpp_2s_ancestral': ['s_noise'], 'sample_dpmpp_sde': ['s_noise'], 'sample_dpmpp_2m_sde': ['s_noise'], 'sample_dpmpp_3m_sde': ['s_noise'], } def kdiffusion_sampler_list(): return [k.label for k in samplers_k_diffusion] k_diffusion_samplers_map = {x.label: x for x in samplers_k_diffusion} k_diffusion_scheduler = {x.name: x.function for x in schedulers} def exists(v): return v is not None class KDiffusionSampler: def __init__(self, sampler, scheduler, sd, device): # k_diffusion_samplers_map[] self.config = k_diffusion_samplers_map[sampler] funcname = self.config.funcname self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, funcname) self.scheduler_name = scheduler self.sd = CFGDenoiser(sd, device) self.model_wrap = self.sd.model_wrap self.device = device self.s_min_uncond = None self.s_churn = 0.0 self.s_tmin = 0.0 self.s_tmax = float('inf') self.s_noise = 1.0 self.eta_option_field = 'eta_ancestral' self.eta_infotext_field = 'Eta' self.eta_default = 1.0 self.eta = None self.extra_params = [] if exists(sd.sigma_max) and exists(sd.sigma_min): self.model_wrap.sigmas[-1] = sd.sigma_max self.model_wrap.sigmas[0] = sd.sigma_min def initialize(self): self.eta = getattr(defaults, self.eta_option_field, 0.0) extra_params_kwargs = {} for param_name in self.extra_params: if param_name in inspect.signature(self.func).parameters: extra_params_kwargs[param_name] = getattr(self, param_name) if 'eta' in inspect.signature(self.func).parameters: extra_params_kwargs['eta'] = self.eta if len(self.extra_params) > 0: s_churn = getattr(defaults, 's_churn', self.s_churn) s_tmin = getattr(defaults, 's_tmin', self.s_tmin) s_tmax = getattr(defaults, 's_tmax', self.s_tmax) or self.s_tmax # 0 = inf s_noise = getattr(defaults, 's_noise', self.s_noise) if 's_churn' in extra_params_kwargs and s_churn != self.s_churn: extra_params_kwargs['s_churn'] = s_churn self.s_churn = s_churn if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin: extra_params_kwargs['s_tmin'] = s_tmin self.s_tmin = s_tmin if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax: extra_params_kwargs['s_tmax'] = s_tmax self.s_tmax = s_tmax if 's_noise' in extra_params_kwargs and s_noise != self.s_noise: extra_params_kwargs['s_noise'] = s_noise self.s_noise = s_noise return extra_params_kwargs def create_noise_sampler(self, x, sigmas, seed): """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" from k_diffusion.sampling import BrownianTreeNoiseSampler sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) def get_sigmas(self, steps, sigmas_min=None, sigmas_max=None): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) steps += 1 if discard_next_to_last_sigma else 0 if self.scheduler_name == 'Automatic': self.scheduler_name = self.config.options.get('scheduler', None) scheduler = schedulers_map.get(self.scheduler_name) sigma_min = default(sigmas_min, self.model_wrap.sigma_min) sigma_max = default(sigmas_max, self.model_wrap.sigma_max) if scheduler is None or scheduler.function is None: sigmas = self.model_wrap.get_sigmas(steps) else: sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max} if scheduler.need_inner_model: sigmas_kwargs['inner_model'] = self.model_wrap sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=self.device) if discard_next_to_last_sigma: sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) return sigmas def __call__(self, x, sigmas, sampler_extra_args, seed, deterministic, steps=None): x = x * sigmas[0] extra_params_kwargs = self.initialize() parameters = inspect.signature(self.func).parameters if 'n' in parameters: extra_params_kwargs['n'] = steps if 'sigma_min' in parameters: extra_params_kwargs['sigma_min'] = sigmas[sigmas > 0].min() extra_params_kwargs['sigma_max'] = sigmas.max() if 'sigmas' in parameters: extra_params_kwargs['sigmas'] = sigmas if self.config.options.get('brownian_noise', False): noise_sampler = self.create_noise_sampler(x, sigmas, seed) if deterministic else None extra_params_kwargs['noise_sampler'] = noise_sampler if self.config.options.get('solver_type', None) == 'heun': extra_params_kwargs['solver_type'] = 'heun' return self.func(self.sd, x, extra_args=sampler_extra_args, disable=False, **extra_params_kwargs)