File size: 9,319 Bytes
1b34e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import torch
import numpy as np

class Sampler:

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float=0.0120):
        #beta is a series of numbers that indicates the variance of the noise that we add with each of these steps
        # the start and end values were a choice made by the authors
        # will be using a linear scheduler, 1000 numbers between start and end

        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2

        # alpha bar is the product of alpha going from 1 to T
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, 0) 
        self.one = torch.tensor(1.0)

        self.generator = generator
        self.num_training_steps = num_training_steps
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())
    
    def set_inference_timesteps(self, num_inference_steps=50):
        self.num_inference_steps = num_inference_steps
        # 999, 998, 997, ... 0 = 1000 steps
        # 999, 999-20, 999-40, ... 0 = 50 steps
        step_ratio = self.num_training_steps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps)
    
    def _get_previous_timestep(self, timestep:int) -> int:
        prev_t = timestep - (self.num_training_steps // self.num_inference_steps)
        return prev_t
    
    def _get_variance(self, timestep: int) -> torch.Tensor:
        prev_t = self._get_previous_timestep(timestep)

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        # Computed using formula (7) of the DDPM paper
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        variance = torch.clamp(variance, min=1e-20)

        return variance
    
    def set_strength(self, strength=1):
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def ddpm_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        t = timestep
        prev_t = self._get_previous_timestep(t)

        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # Compute the predicted original sample using formula (15) of the DDPM paper
        pred_original_sample = (latents - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

        # Compute the coefficient for pred_original_sample and current sample x_t
        pred_original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t

        # Compute the predicted previous sample mean
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        variance = 0
        if t > 0:
            device = model_output.device
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            variance = (self._get_variance(t) ** 0.5) * noise
        
        # N(0,1) --> N(mu, sigma)
        # X = mu + sigma * Z where Z ~ N(0, 1)
        pred_prev_sample = pred_prev_sample + variance
        return pred_prev_sample
    
    def ddim_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor, eta=0.0):
        t = timestep
        prev_t = self._get_previous_timestep(t)
        
        alpha_t = self.alphas_cumprod[t]
        alpha_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0, device=latents.device, dtype=latents.dtype)
        
        # Predicted original clean sample x_0
        pred_original_sample = (latents - torch.sqrt(1 - alpha_t) * model_output) / torch.sqrt(alpha_t)
        
        # Direction pointing to x_t
        #dir_xt = torch.sqrt(1 - alpha_prev - (eta ** 2) * ((1 - alpha_prev) / (1 - alpha_t)) * (1 - alpha_t / alpha_prev)) * model_output
        
        # Noise term
        noise = torch.randn_like(latents) if eta > 0 else torch.zeros_like(latents)
        
        sigma_t = eta * torch.sqrt((1 - alpha_prev) / (1 - alpha_t)) * torch.sqrt(1 - alpha_t / alpha_prev)
        
        # Compute previous latent x_{t-1}
        #prev_latent = torch.sqrt(alpha_prev) * pred_original_sample + dir_xt + sigma_t * noise
        prev_latent = torch.sqrt(alpha_prev) * pred_original_sample + torch.sqrt(1 - alpha_prev - sigma_t ** 2) * model_output + sigma_t * noise

        
        return prev_latent
    
    def euler_ancestral_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor, eta=1.0):
        t = timestep
        prev_t = self._get_previous_timestep(t)
        
        # Convert alphas to sigmas (standard deviation of noise at each timestep)
        alpha_t = self.alphas_cumprod[t]
        alpha_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0, device=latents.device, dtype=latents.dtype)

        sigma_t = torch.sqrt(1 - alpha_t)
        sigma_prev = torch.sqrt(1 - alpha_prev)
        
        # Predict x_0
        x0_pred = (latents - sigma_t * model_output) / torch.sqrt(alpha_t)
        
        # Euler drift step (toward next timestep)
        dt = sigma_prev - sigma_t
        x_drift = latents + dt * model_output

        # Stochastic noise addition
        if eta > 0.0:
            noise = torch.randn_like(latents)
            sigma = torch.sqrt(torch.clamp(eta * (sigma_prev**2 - sigma_t**2), min=1e-20))
            x_drift += sigma * noise

        return x_drift
    
    def dpm_solver_pp_2m_step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """
        One DPM-Solver++(2M) step with DDIM-style signature.

        Args:
            timestep:       Current timestep index t.
            latents:        Latents at current timestep x_t.
            model_output:   Model prediction ε_θ(x_t, t).

        Returns:
            x_{t-1}:        Estimated latent at previous timestep.
        """
        t = self.timesteps[timestep]
        prev_t = self.timesteps[timestep + 1] if timestep + 1 < len(self.timesteps) else 0.0  # t_{prev}

        h = prev_t - t  # Note: time goes backward

        # Extract alpha and sigma for current and previous timesteps
        alpha_t = self.alphas_cumprod[timestep] ** 0.5
        alpha_prev = self.alphas_cumprod[timestep + 1] ** 0.5 if timestep + 1 < len(self.alphas_cumprod) else self.one
        sigma_t = (1 - self.alphas_cumprod[timestep]) ** 0.5
        sigma_prev = (1 - self.alphas_cumprod[timestep + 1]) ** 0.5 if timestep + 1 < len(self.alphas_cumprod) else self.zero

        # Store previous model output if not already done
        if not hasattr(self, "_prev_model_output"):
            self._prev_model_output = model_output  # Just initialize on first call

        model_output_t = model_output
        model_output_prev = self._prev_model_output

        # Compute x0_t and x0_prev estimates
        x0_t = (latents - sigma_t * model_output_t) / alpha_t
        x0_prev = (latents - sigma_t * model_output_prev) / alpha_t

        # 2nd-order multistep estimate
        x0_hat = x0_t + 0.5 * h * (model_output_t - model_output_prev)

        # Estimate x_{t-1}
        x_prev = alpha_prev * x0_hat + sigma_prev * model_output_prev

        # Update previous model output for next step
        self._prev_model_output = model_output

        return x_prev


    def add_noise(self, original_samples: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
        #at what time we want to add the timestep
        alpha_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)
    
        sqrt_alpha_prod = alpha_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) #adds a new dimension with length one at a specific pos within tensors shape
        sqrt_one_minus_alpha_prod = (1 - alpha_cumprod[timesteps]) ** 0.5 #standard deviation
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
        
        # According to equation (4) of the DDPM paper
        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        noisy_samples = (sqrt_alpha_prod * original_samples) + (sqrt_one_minus_alpha_prod) * noise