Spaces:
Running on Zero
| import torch | |
| import k_diffusion | |
| import dataclasses | |
@dataclasses.dataclass
class Scheduler:
    """Registry entry describing one selectable noise (sigma) schedule.

    Instances are built positionally as ``Scheduler(name, label, function, ...)``,
    which requires the ``@dataclasses.dataclass`` decorator to generate
    ``__init__``; without it construction raises ``TypeError``.
    """

    # Internal identifier, e.g. 'karras' or 'sgm_uniform'.
    name: str
    # Human-readable name, e.g. 'Karras'; get_noise_schedulers() returns these.
    label: str
    # Callable that produces the sigma schedule, or None (used by 'automatic').
    # NOTE(review): call signature differs between entries (the local
    # uniform/sgm_uniform take (n, sigma_min, sigma_max, inner_model, device));
    # confirm how the caller dispatches.
    function: any
    # rho value for schedules that take one (7.0 for karras, 1.0 for
    # polyexponential); -1 presumably means "not applicable" -- confirm with caller.
    default_rho: float = -1
    # True for schedules whose function calls methods on inner_model.
    need_inner_model: bool = False
    # Optional alternate names (e.g. ["SGMUniform"]); None when there are none.
    aliases: list = None
def uniform(n, sigma_min, sigma_max, inner_model, device):
    """Return the schedule produced by ``inner_model.get_sigmas(n)``.

    The spacing is whatever the model wrapper implements. ``sigma_min``,
    ``sigma_max`` and ``device`` are unused here; they are accepted so the
    signature matches ``sgm_uniform``.
    """
    get_sigmas = inner_model.get_sigmas
    schedule = get_sigmas(n)
    return schedule
def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
    """Build ``n`` sigmas evenly spaced in timestep, followed by a final 0.0.

    ``sigma_max`` and ``sigma_min`` are converted to timesteps via
    ``inner_model.sigma_to_t``. ``n + 1`` evenly spaced timesteps are taken
    between them, and the last one (the ``sigma_min`` end) is dropped. Each
    remaining timestep is mapped back with ``inner_model.t_to_sigma``, and 0.0
    is appended.

    Returns a 1-D ``torch.FloatTensor`` of length ``n + 1``, moved to ``device``.
    """
    # Compute the max bound first, matching the original call order.
    t_high = inner_model.sigma_to_t(torch.tensor(sigma_max))
    t_low = inner_model.sigma_to_t(torch.tensor(sigma_min))
    timesteps = torch.linspace(t_high, t_low, n + 1)[:-1]

    sigmas = []
    for t in timesteps:
        sigmas.append(inner_model.t_to_sigma(t))
    sigmas.append(0.0)  # terminal sigma of the schedule
    return torch.FloatTensor(sigmas).to(device)
# Every available noise schedule, in the order get_noise_schedulers() reports
# their labels. 'automatic' carries no function of its own (function=None).
schedulers = [
    Scheduler(name='automatic', label='Automatic', function=None),
    Scheduler(
        name='uniform',
        label='Uniform',
        function=uniform,
        need_inner_model=True,
    ),
    Scheduler(
        name='karras',
        label='Karras',
        function=k_diffusion.sampling.get_sigmas_karras,
        default_rho=7.0,
    ),
    Scheduler(
        name='exponential',
        label='Exponential',
        function=k_diffusion.sampling.get_sigmas_exponential,
    ),
    Scheduler(
        name='polyexponential',
        label='Polyexponential',
        function=k_diffusion.sampling.get_sigmas_polyexponential,
        default_rho=1.0,
    ),
    Scheduler(
        name='sgm_uniform',
        label='SGM Uniform',
        function=sgm_uniform,
        need_inner_model=True,
        aliases=["SGMUniform"],
    ),
]
def get_noise_schedulers():
    """Return the ``label`` of every entry in ``schedulers``, in list order."""
    labels = [entry.label for entry in schedulers]
    return labels
# Lookup table keyed by both a scheduler's ``name`` and its ``label``.
# Label pairs come after name pairs, so if a label ever equalled another
# scheduler's name, the label mapping would win. ``aliases`` are not keys here.
schedulers_map = dict(
    [(entry.name, entry) for entry in schedulers]
    + [(entry.label, entry) for entry in schedulers]
)