File size: 2,562 Bytes
2a4c86a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | from __future__ import annotations
from typing import Callable
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
class JiTScheduler(SchedulerMixin, ConfigMixin):
    """Explicit Euler / Heun stepper over a uniformly spaced timestep grid.

    The grid runs from ``timestep_start`` to ``timestep_end`` (both stored in
    the config) and is materialised by :meth:`set_timesteps`. Each call to
    :meth:`step` advances a sample from ``timestep`` to ``next_timestep``
    using the solver selected at construction time.
    """

    # NOTE(review): presumably the solver-order marker that diffusers
    # pipelines read from schedulers -- confirm against the callers.
    order = 1

    @register_to_config
    def __init__(
        self,
        solver: str = "heun",
        timestep_start: float = 0.0,
        timestep_end: float = 1.0,
    ):
        """Validate the configuration and start with an empty timestep grid.

        Args:
            solver: Integration rule, either ``"heun"`` or ``"euler"``.
            timestep_start: First value of the timestep grid.
            timestep_end: Last value of the timestep grid; must be strictly
                greater than ``timestep_start``.

        Raises:
            ValueError: If ``solver`` is not a supported name, or if the
                timestep range is empty or reversed.
        """
        supported_solvers = {"heun", "euler"}
        if solver not in supported_solvers:
            raise ValueError("solver must be one of: 'heun', 'euler'.")
        if timestep_start >= timestep_end:
            raise ValueError("timestep_end must be greater than timestep_start.")

        # Placeholder until set_timesteps() builds the real grid.
        self.timesteps = torch.tensor([])

    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None):
        """Build the uniform timestep grid used for sampling.

        Args:
            num_inference_steps: Number of integration intervals; the grid
                holds ``num_inference_steps + 1`` points so that both
                endpoints are included.
            device: Device on which the grid tensor is created.

        Raises:
            ValueError: If ``num_inference_steps`` is smaller than 2.
        """
        if num_inference_steps < 2:
            raise ValueError("num_inference_steps must be >= 2.")

        grid_start = self.config.timestep_start
        grid_end = self.config.timestep_end
        num_points = num_inference_steps + 1
        self.timesteps = torch.linspace(
            grid_start,
            grid_end,
            num_points,
            device=device,
            dtype=torch.float32,
        )

    @staticmethod
    def _wrap_output(prev_sample: torch.Tensor, return_dict: bool):
        """Return ``prev_sample`` as a ``SchedulerOutput`` or a 1-tuple."""
        if return_dict:
            return SchedulerOutput(prev_sample=prev_sample)
        return (prev_sample,)

    def euler_step(
        self,
        model_output: torch.Tensor,
        timestep: torch.Tensor,
        next_timestep: torch.Tensor,
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> SchedulerOutput | tuple[torch.Tensor]:
        """Advance ``sample`` by one explicit Euler update.

        Args:
            model_output: Slope evaluated at ``(sample, timestep)``.
            timestep: Current timestep.
            next_timestep: Timestep to integrate to.
            sample: Current sample.
            return_dict: When true, wrap the result in ``SchedulerOutput``;
                otherwise return a 1-tuple.

        Returns:
            The updated sample, wrapped according to ``return_dict``.
        """
        step_size = next_timestep - timestep
        advanced = sample + step_size * model_output
        return self._wrap_output(advanced, return_dict)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: torch.Tensor,
        next_timestep: torch.Tensor,
        sample: torch.Tensor,
        model_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
        return_dict: bool = True,
    ) -> SchedulerOutput | tuple[torch.Tensor]:
        """Advance ``sample`` from ``timestep`` to ``next_timestep``.

        With the ``"euler"`` solver this delegates to :meth:`euler_step`.
        Otherwise an Euler predictor is formed, ``model_fn`` is evaluated at
        the predicted point, and the two slopes are averaged.

        Args:
            model_output: Slope evaluated at ``(sample, timestep)``.
            timestep: Current timestep.
            next_timestep: Timestep to integrate to.
            sample: Current sample.
            model_fn: Callable invoked as ``model_fn(sample, timestep)`` to
                obtain the slope at the predicted point; required unless the
                solver is ``"euler"``.
            return_dict: When true, wrap the result in ``SchedulerOutput``;
                otherwise return a 1-tuple.

        Returns:
            The updated sample, wrapped according to ``return_dict``.

        Raises:
            ValueError: If the Heun solver is selected and ``model_fn`` is
                ``None``.
        """
        if self.config.solver == "euler":
            return self.euler_step(
                model_output,
                timestep,
                next_timestep,
                sample,
                return_dict=return_dict,
            )

        if model_fn is None:
            raise ValueError("model_fn is required when solver='heun'.")

        step_size = next_timestep - timestep

        # Predictor: plain Euler estimate of the sample at next_timestep.
        predicted = sample + step_size * model_output
        slope_at_prediction = model_fn(predicted, next_timestep)

        # Corrector: step with the average of the start and predicted slopes.
        corrected = sample + step_size * 0.5 * (model_output + slope_at_prediction)
        return self._wrap_output(corrected, return_dict)
|