| from __future__ import annotations |
|
|
| from typing import Callable |
|
|
| import torch |
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput |
|
|
|
|
| class JiTScheduler(SchedulerMixin, ConfigMixin): |
| order = 1 |
|
|
| @register_to_config |
| def __init__( |
| self, |
| solver: str = "heun", |
| timestep_start: float = 0.0, |
| timestep_end: float = 1.0, |
| ): |
| if solver not in {"heun", "euler"}: |
| raise ValueError("solver must be one of: 'heun', 'euler'.") |
| if timestep_end <= timestep_start: |
| raise ValueError("timestep_end must be greater than timestep_start.") |
| self.timesteps = torch.tensor([]) |
|
|
| def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None): |
| if num_inference_steps < 2: |
| raise ValueError("num_inference_steps must be >= 2.") |
| self.timesteps = torch.linspace( |
| self.config.timestep_start, |
| self.config.timestep_end, |
| num_inference_steps + 1, |
| device=device, |
| dtype=torch.float32, |
| ) |
|
|
| def euler_step( |
| self, |
| model_output: torch.Tensor, |
| timestep: torch.Tensor, |
| next_timestep: torch.Tensor, |
| sample: torch.Tensor, |
| return_dict: bool = True, |
| ) -> SchedulerOutput | tuple[torch.Tensor]: |
| prev_sample = sample + (next_timestep - timestep) * model_output |
| if not return_dict: |
| return (prev_sample,) |
| return SchedulerOutput(prev_sample=prev_sample) |
|
|
| def step( |
| self, |
| model_output: torch.Tensor, |
| timestep: torch.Tensor, |
| next_timestep: torch.Tensor, |
| sample: torch.Tensor, |
| model_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, |
| return_dict: bool = True, |
| ) -> SchedulerOutput | tuple[torch.Tensor]: |
| if self.config.solver == "euler": |
| return self.euler_step(model_output, timestep, next_timestep, sample, return_dict=return_dict) |
|
|
| if model_fn is None: |
| raise ValueError("model_fn is required when solver='heun'.") |
|
|
| sample_euler = sample + (next_timestep - timestep) * model_output |
| model_output_next = model_fn(sample_euler, next_timestep) |
| prev_sample = sample + (next_timestep - timestep) * 0.5 * (model_output + model_output_next) |
|
|
| if not return_dict: |
| return (prev_sample,) |
| return SchedulerOutput(prev_sample=prev_sample) |
|
|