File size: 2,562 Bytes
2a4c86a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import annotations

from typing import Callable

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput


class JiTScheduler(SchedulerMixin, ConfigMixin):
    """Flow-matching scheduler that integrates an ODE with Euler or Heun.

    Samples are advanced along a uniform timestep grid running from
    ``timestep_start`` to ``timestep_end`` (presumably the [0, 1] flow-matching
    convention — confirm against the training setup).
    """

    # First-order scheduler from the perspective of diffusers pipelines.
    order = 1

    @register_to_config
    def __init__(
        self,
        solver: str = "heun",
        timestep_start: float = 0.0,
        timestep_end: float = 1.0,
    ):
        """Validate the configuration and start with an empty timestep grid.

        Args:
            solver: ODE solver to use; one of ``"heun"`` or ``"euler"``.
            timestep_start: First time value of the integration interval.
            timestep_end: Last time value; must exceed ``timestep_start``.

        Raises:
            ValueError: If ``solver`` is unknown or the interval is
                non-increasing.
        """
        if solver not in {"heun", "euler"}:
            raise ValueError("solver must be one of: 'heun', 'euler'.")
        if timestep_end <= timestep_start:
            raise ValueError("timestep_end must be greater than timestep_start.")
        # Populated by set_timesteps(); empty until then.
        self.timesteps = torch.tensor([])

    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None):
        """Build a uniform grid of ``num_inference_steps + 1`` timesteps.

        The grid spans ``[timestep_start, timestep_end]`` inclusive, so it has
        one more point than the number of integration steps.

        Args:
            num_inference_steps: Number of solver steps; must be at least 2.
            device: Optional device on which to allocate the grid.

        Raises:
            ValueError: If ``num_inference_steps`` is below 2.
        """
        if num_inference_steps < 2:
            raise ValueError("num_inference_steps must be >= 2.")
        grid = torch.linspace(
            self.config.timestep_start,
            self.config.timestep_end,
            num_inference_steps + 1,
            dtype=torch.float32,
            device=device,
        )
        self.timesteps = grid

    def euler_step(
        self,
        model_output: torch.Tensor,
        timestep: torch.Tensor,
        next_timestep: torch.Tensor,
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> SchedulerOutput | tuple[torch.Tensor]:
        """Advance ``sample`` by one explicit-Euler step.

        Args:
            model_output: Velocity/drift predicted at ``timestep``.
            timestep: Current time value.
            next_timestep: Time value to step to.
            sample: Current sample.
            return_dict: If True, wrap the result in a ``SchedulerOutput``.

        Returns:
            The stepped sample, either wrapped or as a 1-tuple.
        """
        dt = next_timestep - timestep
        stepped = sample + dt * model_output
        if return_dict:
            return SchedulerOutput(prev_sample=stepped)
        return (stepped,)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: torch.Tensor,
        next_timestep: torch.Tensor,
        sample: torch.Tensor,
        model_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
        return_dict: bool = True,
    ) -> SchedulerOutput | tuple[torch.Tensor]:
        """Advance ``sample`` by one step of the configured solver.

        For ``solver="euler"`` this delegates to :meth:`euler_step`. For
        ``solver="heun"`` it runs a predictor-corrector step and therefore
        needs ``model_fn`` to re-evaluate the model at the predicted point.

        Args:
            model_output: Velocity/drift predicted at ``timestep``.
            timestep: Current time value.
            next_timestep: Time value to step to.
            sample: Current sample.
            model_fn: Callable ``(sample, timestep) -> model_output``; required
                only for the Heun solver.
            return_dict: If True, wrap the result in a ``SchedulerOutput``.

        Returns:
            The stepped sample, either wrapped or as a 1-tuple.

        Raises:
            ValueError: If the Heun solver is selected but ``model_fn`` is
                missing.
        """
        if self.config.solver == "euler":
            return self.euler_step(model_output, timestep, next_timestep, sample, return_dict=return_dict)

        if model_fn is None:
            raise ValueError("model_fn is required when solver='heun'.")

        dt = next_timestep - timestep
        # Predictor: plain Euler step to the next grid point.
        predicted = sample + dt * model_output
        # Corrector: average the slopes at both endpoints (trapezoidal rule).
        slope_next = model_fn(predicted, next_timestep)
        stepped = sample + dt * 0.5 * (model_output + slope_next)

        if return_dict:
            return SchedulerOutput(prev_sample=stepped)
        return (stepped,)