File size: 1,785 Bytes
705a8fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from . import gaussian_diffusion as gd_orig
from . import gaussian_diffusion_dual as gd_dual
# from .respace import SpacedDiffusion, space_timesteps


def create_diffusion(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000,
    dual=False
):
    """Build a ``SpacedDiffusion`` from either the single or the dual diffusion variant.

    Args:
        timestep_respacing: Forwarded to ``space_timesteps`` as its second
            argument. ``None`` or ``""`` is replaced by ``[diffusion_steps]``
            first (presumably meaning "no respacing" — confirm against
            ``space_timesteps``).
        noise_schedule: Schedule name passed to ``get_named_beta_schedule``.
        use_kl: If truthy, the loss type is ``LossType.RESCALED_KL``.
        sigma_small: Only consulted when ``learn_sigma`` is falsy; selects
            ``ModelVarType.FIXED_SMALL`` instead of ``FIXED_LARGE``.
        predict_xstart: If truthy, the mean type is ``ModelMeanType.START_X``;
            otherwise ``ModelMeanType.EPSILON``.
        learn_sigma: If truthy, the variance type is ``ModelVarType.LEARNED_RANGE``.
        rescale_learned_sigmas: If truthy (and ``use_kl`` is falsy), the loss
            type is ``LossType.RESCALED_MSE``. When both flags are falsy, the loss
            type is ``LossType.MSE``.
        diffusion_steps: Passed to ``get_named_beta_schedule`` and as the first
            argument of ``space_timesteps``.
        dual: If truthy, use ``gaussian_diffusion_dual`` / ``respace_dual``;
            otherwise ``gaussian_diffusion`` / ``respace``. A line naming the
            chosen variant is printed to stdout.

    Returns:
        A ``SpacedDiffusion`` instance from the selected respace module.
    """
    # Only the selected respace variant is imported, and only at call time.
    if not dual:
        print("Using SINGLE diffusion")
        from .respace import SpacedDiffusion, space_timesteps
        diffusion_mod = gd_orig
    else:
        print("Using DUAL diffusion")
        from .respace_dual import SpacedDiffusion, space_timesteps
        diffusion_mod = gd_dual

    betas = diffusion_mod.get_named_beta_schedule(noise_schedule, diffusion_steps)

    # use_kl takes precedence over rescale_learned_sigmas.
    losses = diffusion_mod.LossType
    if use_kl:
        chosen_loss = losses.RESCALED_KL
    elif rescale_learned_sigmas:
        chosen_loss = losses.RESCALED_MSE
    else:
        chosen_loss = losses.MSE

    # None / "" are normalised to a single-element list before respacing.
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]
    kept_steps = space_timesteps(diffusion_steps, timestep_respacing)

    means = diffusion_mod.ModelMeanType
    chosen_mean = means.START_X if predict_xstart else means.EPSILON

    # A learned variance overrides the fixed small/large choice.
    variances = diffusion_mod.ModelVarType
    if learn_sigma:
        chosen_var = variances.LEARNED_RANGE
    elif sigma_small:
        chosen_var = variances.FIXED_SMALL
    else:
        chosen_var = variances.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=kept_steps,
        betas=betas,
        model_mean_type=chosen_mean,
        model_var_type=chosen_var,
        loss_type=chosen_loss,
    )