| | """ |
| | Gaussian Diffusion (DDPM) framework for PDE next-frame prediction. |
| | """ |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import math |
| |
|
| |
|
class GaussianDiffusion(nn.Module):
    """DDPM with linear beta schedule.

    Training: given (condition, target), add noise to target, predict noise.
    Sampling: iteratively denoise starting from Gaussian noise.

    Args:
        model: U-Net (or any eps-predicting network) called as
            ``model(x_noisy, t, cond=x_cond)`` and returning a noise estimate
            with the same shape as ``x_noisy``.
        timesteps: number of diffusion steps T.
        beta_start: starting noise level of the linear schedule.
        beta_end: ending noise level of the linear schedule.
    """

    def __init__(self, model, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.model = model
        self.T = timesteps

        # Linear schedule: beta_t rises from beta_start to beta_end;
        # alpha_bar_t is the cumulative product of (1 - beta).
        betas = torch.linspace(beta_start, beta_end, timesteps)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)

        # Buffers (not parameters) so they follow .to(device) and state_dict.
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alpha_bar", alpha_bar)
        self.register_buffer("sqrt_alpha_bar", torch.sqrt(alpha_bar))
        self.register_buffer("sqrt_one_minus_alpha_bar", torch.sqrt(1 - alpha_bar))
        self.register_buffer("sqrt_recip_alpha", torch.sqrt(1.0 / alphas))
        # Variance of q(x_{t-1} | x_t, x_0); the pad supplies alpha_bar_{-1} = 1.
        self.register_buffer(
            "posterior_variance",
            betas * (1.0 - F.pad(alpha_bar[:-1], (1, 0), value=1.0)) / (1.0 - alpha_bar),
        )

    @staticmethod
    def _extract(buf, t, ndim):
        """Gather per-sample schedule values buf[t] ([B]) and reshape to
        [B, 1, ..., 1] so they broadcast over an ndim-rank batch tensor."""
        return buf[t].view(-1, *([1] * (ndim - 1)))

    def q_sample(self, x0, t, noise=None):
        """Forward process: sample x_t ~ q(x_t | x_0) at timestep t.

        Args:
            x0: clean data, any rank with batch dim first (e.g. [B, C, H, W]).
            t: integer timesteps [B].
            noise: optional pre-drawn standard normal noise; sampled if None.

        Returns:
            (x_t, noise) where x_t = sqrt(alpha_bar_t) * x0
            + sqrt(1 - alpha_bar_t) * noise.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        a = self._extract(self.sqrt_alpha_bar, t, x0.dim())
        b = self._extract(self.sqrt_one_minus_alpha_bar, t, x0.dim())
        return a * x0 + b * noise, noise

    def training_loss(self, x_target, x_cond):
        """Compute training loss (predict noise).

        Draws a uniform random timestep per sample, noises the target, and
        regresses the model output against the injected noise.

        Args:
            x_target: clean target frames [B, C, H, W].
            x_cond: condition frames [B, C, H, W].

        Returns:
            scalar MSE loss.
        """
        B = x_target.shape[0]
        t = torch.randint(0, self.T, (B,), device=x_target.device)
        noise = torch.randn_like(x_target)
        x_noisy, _ = self.q_sample(x_target, t, noise)

        eps_pred = self.model(x_noisy, t, cond=x_cond)
        return F.mse_loss(eps_pred, noise)

    @torch.no_grad()
    def sample(self, x_cond, shape=None):
        """Generate target frames by iterative denoising (DDPM ancestral sampling).

        Args:
            x_cond: condition frames [B, C_cond, H, W].
            shape: (B, C_out, H, W) of the target. Inferred from x_cond if None.

        Returns:
            denoised sample [B, C_out, H, W].
        """
        device = x_cond.device
        if shape is None:
            shape = x_cond.shape

        # Start from pure Gaussian noise x_T.
        x = torch.randn(shape, device=device)

        for i in reversed(range(self.T)):
            t = torch.full((shape[0],), i, device=device, dtype=torch.long)
            eps = self.model(x, t, cond=x_cond)

            # Posterior mean: 1/sqrt(alpha_t) * (x_t - beta_t/sqrt(1-alpha_bar_t) * eps),
            # using the buffers precomputed in __init__.
            coef = self.betas[i] / self.sqrt_one_minus_alpha_bar[i]
            mean = self.sqrt_recip_alpha[i] * (x - coef * eps)

            if i > 0:
                # Add posterior noise at every step except the last.
                sigma = self.posterior_variance[i].sqrt()
                x = mean + sigma * torch.randn_like(x)
            else:
                x = mean

        return x

    @torch.no_grad()
    def sample_ddim(self, x_cond, shape=None, steps=50, eta=0.0):
        """DDIM accelerated sampling.

        Args:
            x_cond: condition [B, C_cond, H, W].
            shape: target shape. Inferred from x_cond if None.
            steps: number of DDIM steps (<< T for speed).
            eta: stochasticity (0 = deterministic DDIM, 1 = DDPM-like).

        Returns:
            denoised sample [B, C_out, H, W] (state at the smallest timestep).
        """
        device = x_cond.device
        if shape is None:
            shape = x_cond.shape

        # Evenly spaced sub-schedule over [0, T-1], traversed high -> low.
        step_indices = torch.linspace(0, self.T - 1, steps + 1, dtype=torch.long, device=device)
        step_indices = step_indices.flip(0)

        x = torch.randn(shape, device=device)

        for idx in range(len(step_indices) - 1):
            t_cur = step_indices[idx]
            t_next = step_indices[idx + 1]

            t_batch = t_cur.expand(shape[0])
            eps = self.model(x, t_batch, cond=x_cond)

            ab_cur = self.alpha_bar[t_cur]
            ab_next = self.alpha_bar[t_next]

            # Predict x0 from the current state and clamp to a sane range
            # to keep early (high-noise) estimates from exploding.
            x0_pred = (x - (1 - ab_cur).sqrt() * eps) / ab_cur.sqrt()
            x0_pred = x0_pred.clamp(-5, 5)

            # DDIM update: sigma controls injected stochasticity (0 for eta=0).
            sigma = eta * ((1 - ab_next) / (1 - ab_cur) * (1 - ab_cur / ab_next)).sqrt()
            dir_xt = (1 - ab_next - sigma**2).sqrt() * eps

            x = ab_next.sqrt() * x0_pred + dir_xt
            if sigma > 0:
                x = x + sigma * torch.randn_like(x)

        return x