| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from typing import List, Optional, Dict |
| from tqdm import tqdm |
|
|
|
|
class SinusoidalPositionEmbeddings(nn.Module):
    """Map a 1-D tensor of (diffusion) timesteps to sinusoidal feature vectors.

    The first ``dim // 2`` output features are sines and the next ``dim // 2``
    are cosines, evaluated at geometrically spaced frequencies between 1 and
    1/10000. For an odd ``dim`` a single zero column is appended so that the
    output always has exactly ``dim`` features.

    Args:
        dim: Size of the embedding produced for every timestep (must be >= 2).

    Raises:
        ValueError: If ``dim`` is smaller than 2.
    """

    def __init__(self, dim: int):
        super().__init__()
        if dim < 2:
            raise ValueError(f"dim must be >= 2, got {dim}")
        self.dim = dim

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        """Return a ``(len(time), dim)`` embedding for the 1-D tensor ``time``."""
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when dim is 2 or 3
        # (half_dim == 1); for larger dims the value is unchanged.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        angles = time[:, None] * frequencies[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover dim - 1 features; pad to the promised width.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings
|
|
|
|
class ResnetBlock1D(nn.Module):
    """Residual block of two 3-tap 1-D convolutions with optional scale/shift conditioning.

    Input and output are ``(batch, channels, length)`` tensors. When a
    conditioning vector is supplied (and the block was built with
    ``time_emb_dim``), it is projected to a per-channel scale and shift that
    modulate the output of the first (non-affine) GroupNorm.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the block (GroupNorm uses 8 groups).
        time_emb_dim: Width of the conditioning vector, or ``None`` to build
            the block without a conditioning projection.
        dropout: Dropout probability applied after the second activation.
    """

    def __init__(self, in_channels: int, out_channels: int, *, time_emb_dim: int = None, dropout: float = 0.1):
        super().__init__()

        # Conditioning projection: one scale and one shift value per output channel.
        if time_emb_dim is None:
            self.time_mlp = None
        else:
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, out_channels * 2),
            )

        # First stage; the norm has no affine parameters because the
        # conditioning supplies the scale and shift instead.
        self.block1_conv = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.block1_norm = nn.GroupNorm(8, out_channels, affine=False)
        self.block1_act = nn.SiLU()

        # Second stage.
        self.block2_conv = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.block2_norm = nn.GroupNorm(8, out_channels)
        self.block2_act = nn.SiLU()
        self.block2_dropout = nn.Dropout(dropout)

        # Skip path: 1x1 conv only when the channel count changes.
        if in_channels == out_channels:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor = None) -> torch.Tensor:
        """Apply the block to ``x``; ``time_emb`` is a ``(batch, time_emb_dim)`` tensor or ``None``."""
        hidden = self.block1_norm(self.block1_conv(x))

        conditioned = not (self.time_mlp is None or time_emb is None)
        if conditioned:
            scale, shift = self.time_mlp(time_emb).chunk(2, dim=1)
            # "+ 1" keeps the modulation centred on identity.
            hidden = hidden * (scale.unsqueeze(-1) + 1) + shift.unsqueeze(-1)

        hidden = self.block1_act(hidden)

        hidden = self.block2_conv(hidden)
        hidden = self.block2_norm(hidden)
        hidden = self.block2_act(hidden)
        hidden = self.block2_dropout(hidden)
        return hidden + self.res_conv(x)
|
|
|
|
class AttentionBlock1D(nn.Module):
    """Multi-head self-attention over the length axis of a ``(batch, channels, length)`` tensor.

    The input is group-normalised, projected to queries/keys/values with a
    1x1 convolution, attended with ``F.scaled_dot_product_attention`` and
    projected back; the result is added to the input as a residual.

    Args:
        channels: Number of input/output channels (GroupNorm uses 8 groups).
        num_heads: Number of attention heads; must divide ``channels``.
    """

    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        per_head, leftover = divmod(channels, num_heads)
        assert leftover == 0, "channels must be divisible by num_heads"
        self.head_dim = per_head

        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the projected self-attention output (same shape as ``x``)."""
        batch, chans, length = x.shape

        # (B, 3C, L) -> (B, 3, heads, head_dim, L) -> (3, B, heads, L, head_dim)
        packed = self.qkv(self.norm(x)).view(batch, 3, self.num_heads, self.head_dim, length)
        query, key, value = packed.permute(1, 0, 2, 4, 3).unbind(0)

        attended = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)

        # (B, heads, L, head_dim) -> (B, heads, head_dim, L) -> (B, C, L)
        merged = attended.transpose(2, 3).contiguous().view(batch, chans, length)

        return x + self.proj(merged)
|
|
|
|
class DownBlock1D(nn.Module):
    """Encoder stage: residual blocks, optional attention, then a stride-2 downsampling conv.

    Args:
        in_channels: Channels entering the stage.
        out_channels: Channels produced by the stage.
        time_emb_dim: Width of the conditioning vector handed to each residual block.
        dropout: Dropout probability of the residual blocks.
        use_attention: Whether to add self-attention after the residual blocks.
        num_blocks: Number of residual blocks in the stage.
    """

    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int, dropout: float, use_attention: bool, num_blocks: int = 2):
        super().__init__()

        stages = []
        for index in range(num_blocks):
            # Only the first block changes the channel count.
            block_in = in_channels if index == 0 else out_channels
            stages.append(
                ResnetBlock1D(block_in, out_channels, time_emb_dim=time_emb_dim, dropout=dropout)
            )
        self.resnets = nn.ModuleList(stages)

        if use_attention:
            self.attn = AttentionBlock1D(out_channels)
        else:
            self.attn = nn.Identity()

        self.downsampler = nn.Conv1d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, time_emb):
        """Return ``(downsampled, skip)``; ``skip`` is the feature map before downsampling."""
        for resnet in self.resnets:
            x = resnet(x, time_emb)
        skip = self.attn(x)
        return self.downsampler(skip), skip
|
|
|
|
class UpBlock1D(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, residual blocks, optional attention.

    Args:
        in_channels: Channels of the incoming tensor (and of the skip tensor
            it is concatenated with).
        out_channels: Channels produced by the stage.
        time_emb_dim: Width of the conditioning vector handed to each residual block.
        dropout: Dropout probability of the residual blocks.
        use_attention: Whether to add self-attention after the residual blocks.
        num_blocks: Number of residual blocks in the stage.
    """

    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int, dropout: float, use_attention: bool, num_blocks: int = 2):
        super().__init__()

        # The first block consumes the concatenation [skip, upsampled] -> 2 * in_channels.
        stages = [ResnetBlock1D(in_channels * 2, out_channels, time_emb_dim=time_emb_dim, dropout=dropout)]
        stages.extend(
            ResnetBlock1D(out_channels, out_channels, time_emb_dim=time_emb_dim, dropout=dropout)
            for _ in range(num_blocks - 1)
        )
        self.resnets = nn.ModuleList(stages)

        if use_attention:
            self.attn = AttentionBlock1D(out_channels)
        else:
            self.attn = nn.Identity()

        self.upsampler = nn.ConvTranspose1d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, skip_x, time_emb):
        """Upsample ``x``, align its length to ``skip_x``, concatenate on channels and refine."""
        upsampled = self.upsampler(x)

        # Align lengths: zero-pad both sides when too short, crop the tail when too long.
        target_len = skip_x.size(-1)
        gap = target_len - upsampled.size(-1)
        if gap > 0:
            left = gap // 2
            upsampled = F.pad(upsampled, [left, gap - left])
        elif gap < 0:
            upsampled = upsampled[:, :, :target_len]

        merged = torch.cat([skip_x, upsampled], dim=1)

        for resnet in self.resnets:
            merged = resnet(merged, time_emb)
        return self.attn(merged)
|
|
|
|
class ConditionalUnet(nn.Module):
    """1-D U-Net noise predictor conditioned on timestep, house id and calendar indices.

    The network consumes ``(batch, length, in_channels)`` tensors (they are
    permuted to channels-first internally) and returns a tensor of the same
    layout. Timestep, house, day-of-week and day-of-year embeddings are
    projected to a common width and summed into one conditioning vector that
    is passed to every encoder, bottleneck and decoder residual block.

    Args:
        in_channels: Number of feature channels of the signal being denoised.
        num_houses: Size of the house-id embedding table.
        embedding_dim: Width of the house / calendar embedding tables.
        hidden_dims: Channel width per resolution level; ``None`` selects
            ``[64, 128, 256]``. The first entry also fixes the conditioning
            width (``4 * hidden_dims[0]``).
        dropout: Dropout probability of the residual blocks.
        use_attention: Whether encoder/decoder stages include self-attention
            (the bottleneck always does).
        cond_channels: Channels of the optional ``conditioning_signal`` that
            is concatenated to the input before the first convolution.
        blocks_per_level: Residual blocks per encoder/decoder stage.
    """

    def __init__(self, in_channels: int, num_houses: int, embedding_dim: int = 64,
                 hidden_dims: Optional[List[int]] = None,
                 dropout: float = 0.1, use_attention: bool = True,
                 cond_channels: int = 0, blocks_per_level: int = 2):
        super().__init__()
        # A None sentinel replaces the former mutable list default; the
        # copy also lets callers pass any sequence (e.g. a tuple).
        hidden_dims = [64, 128, 256] if hidden_dims is None else list(hidden_dims)
        time_emb_dim = hidden_dims[0] * 4

        # Timestep -> sinusoidal features -> 2-layer MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(hidden_dims[0]),
            nn.Linear(hidden_dims[0], time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )

        self.house_embedding = nn.Embedding(num_houses, embedding_dim)
        self.house_proj = nn.Linear(embedding_dim, time_emb_dim)

        # Tables sized 7 and 366, so indices must lie in 0..6 and 0..365.
        self.day_of_week_embedding = nn.Embedding(7, embedding_dim)
        self.day_of_year_embedding = nn.Embedding(366, embedding_dim)

        self.day_of_week_proj = nn.Linear(embedding_dim, time_emb_dim)
        self.day_of_year_proj = nn.Linear(embedding_dim, time_emb_dim)

        self.init_conv = nn.Conv1d(in_channels + cond_channels, hidden_dims[0], kernel_size=7, padding=3)

        num_resolutions = len(hidden_dims)
        self.down_blocks = nn.ModuleList([
            DownBlock1D(hidden_dims[i], hidden_dims[i+1], time_emb_dim, dropout, use_attention, blocks_per_level)
            for i in range(num_resolutions - 1)
        ])

        self.mid_block1 = ResnetBlock1D(hidden_dims[-1], hidden_dims[-1], time_emb_dim=time_emb_dim, dropout=dropout)
        self.mid_attn = AttentionBlock1D(hidden_dims[-1])
        self.mid_block2 = ResnetBlock1D(hidden_dims[-1], hidden_dims[-1], time_emb_dim=time_emb_dim, dropout=dropout)

        # Decoder mirrors the encoder, deepest level first.
        self.up_blocks = nn.ModuleList([
            UpBlock1D(hidden_dims[i+1], hidden_dims[i], time_emb_dim, dropout, use_attention, blocks_per_level)
            for i in reversed(range(num_resolutions - 1))
        ])

        # NOTE(review): nn.Sequential calls the ResnetBlock1D with ``x`` only,
        # so this head runs without the conditioning vector and the block's
        # own ``time_mlp`` weights are never used. Left as-is to keep the
        # state_dict layout and the behaviour of existing checkpoints.
        self.final_conv = nn.Sequential(
            ResnetBlock1D(hidden_dims[0], hidden_dims[0], time_emb_dim=time_emb_dim, dropout=dropout),
            nn.Conv1d(hidden_dims[0], in_channels, 1)
        )

    def forward(self, x: torch.Tensor, timestep: torch.Tensor, conditions: Dict[str, torch.Tensor],
                conditioning_signal: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Predict the noise for ``x``.

        Args:
            x: Noisy input, ``(batch, length, in_channels)``.
            timestep: 1-D tensor with one diffusion step per batch element.
            conditions: Dict with index tensors under the keys ``"house_id"``,
                ``"day_of_week"`` and ``"day_of_year"``.
            conditioning_signal: Optional ``(batch, length, cond_channels)``
                tensor concatenated to ``x`` on the channel axis.

        Returns:
            Tensor of shape ``(batch, length, in_channels)``.
        """
        time_emb = self.time_mlp(timestep)

        house_id = conditions["house_id"]
        day_of_week = conditions["day_of_week"]
        day_of_year = conditions["day_of_year"]

        house_emb = self.house_proj(self.house_embedding(house_id))
        dow_emb = self.day_of_week_proj(self.day_of_week_embedding(day_of_week))
        doy_emb = self.day_of_year_proj(self.day_of_year_embedding(day_of_year))

        # All conditioning sources share one width and are simply summed.
        emb = time_emb + house_emb + dow_emb + doy_emb

        # Conv1d expects channels-first: (B, L, C) -> (B, C, L).
        x = x.permute(0, 2, 1)
        if conditioning_signal is not None:
            x = torch.cat([x, conditioning_signal.permute(0, 2, 1)], dim=1)

        x = self.init_conv(x)

        skip_connections = []
        for down_block in self.down_blocks:
            x, skip_x = down_block(x, emb)
            skip_connections.append(skip_x)

        x = self.mid_block1(x, emb)
        x = self.mid_attn(x)
        x = self.mid_block2(x, emb)

        # Skips are consumed in reverse order (deepest first).
        for up_block in self.up_blocks:
            x = up_block(x, skip_connections.pop(), emb)

        return self.final_conv(x).permute(0, 2, 1)
|
|
|
|
class ImprovedDiffusionModel(nn.Module):
    """Gaussian diffusion wrapper (cosine schedule, noise-prediction Huber loss) around a denoiser.

    ``base_model`` is called as ``base_model(x_t, t, conditions, conditioning_signal)``
    and must return a noise estimate shaped like ``x_t``.

    Args:
        base_model: The denoising network.
        num_timesteps: Number of diffusion steps.
        channel_weights: Optional per-channel loss weights; broadcast against
            the last axis of the loss. Stored as a non-persistent buffer so it
            follows ``.to()`` / ``.cuda()`` without entering ``state_dict()``.
    """

    def __init__(self, base_model: "ConditionalUnet", num_timesteps: int,
                 channel_weights: Optional[torch.Tensor] = None):
        super().__init__()
        self.model = base_model
        self.num_timesteps = num_timesteps

        if isinstance(channel_weights, nn.Parameter):
            # Plain assignment keeps a caller-supplied Parameter learnable.
            self.channel_weights = channel_weights
        else:
            if channel_weights is not None:
                channel_weights = torch.as_tensor(channel_weights)
            # persistent=False leaves existing checkpoints loadable unchanged.
            self.register_buffer('channel_weights', channel_weights, persistent=False)

        betas = self._cosine_beta_schedule(num_timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Cumulative product shifted by one step, with alpha_bar_{-1} := 1.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))

        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # The value at t=0 is exactly 0; clamp so later sqrt/log stay finite.
        posterior_variance = torch.clamp(posterior_variance, min=1e-20)
        self.register_buffer('posterior_variance', posterior_variance)

    def _cosine_beta_schedule(self, timesteps, s=0.008):
        """Return ``timesteps`` betas from a squared-cosine cumulative-alpha curve.

        Computed in float64, clipped to ``[1e-4, 0.9999]`` and returned as float32.
        """
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999).float()

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to step ``t``: ``sqrt(a_bar_t) * x_start + sqrt(1 - a_bar_t) * noise``.

        ``t`` holds one step index per batch element; ``noise`` defaults to
        fresh standard-normal noise.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        # One coefficient per sample, broadcast over all remaining axes.
        broadcast_shape = (-1,) + (1,) * (x_start.dim() - 1)
        signal_scale = self.sqrt_alphas_cumprod[t].view(broadcast_shape)
        noise_scale = self.sqrt_one_minus_alphas_cumprod[t].view(broadcast_shape)
        return signal_scale * x_start + noise_scale * noise

    def forward(self, x_0: torch.Tensor, conditions: Dict[str, torch.Tensor],
                conditioning_signal: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the scalar training loss for a batch of clean samples ``x_0``.

        A random step is drawn per sample, ``x_0`` is noised to that step and
        the Huber loss between the predicted and the true noise is averaged
        (after optional per-channel weighting).
        """
        t = torch.randint(0, self.num_timesteps, (x_0.shape[0],), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = self.q_sample(x_0, t, noise)
        predicted_noise = self.model(x_t, t, conditions, conditioning_signal)

        # Conventional (input, target) order; Huber is symmetric, so the value is unchanged.
        loss = F.huber_loss(predicted_noise, noise, reduction='none')

        if self.channel_weights is None:
            return loss.mean()

        # The mean is taken over all elements and is not renormalised by the weight sum.
        weights = self.channel_weights.to(loss.device).view(1, 1, -1)
        return (loss * weights).mean()

    @torch.no_grad()
    def sample(self, num_samples: int, conditions: Dict[str, torch.Tensor], shape: tuple,
               conditioning_signal: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Generate ``num_samples`` samples by ancestral sampling from pure noise.

        Args:
            num_samples: Batch size of the generated tensor.
            conditions: Conditioning dict forwarded to the denoiser unchanged.
            shape: Per-sample shape appended after the batch dimension.
            conditioning_signal: Optional tensor forwarded to the denoiser.

        Returns:
            Tensor of shape ``(num_samples, *shape)`` on the denoiser's device.

        Note:
            This method does not switch the module to eval mode; callers that
            want dropout disabled must call ``.eval()`` themselves.
        """
        device = next(self.model.parameters()).device
        x = torch.randn(num_samples, *shape, device=device)

        for t in tqdm(reversed(range(self.num_timesteps)), desc="Sampling", total=self.num_timesteps, leave=False):
            t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)
            predicted_noise = self.model(x, t_batch, conditions, conditioning_signal)

            alpha_t = self.alphas[t]
            sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t]

            # Posterior mean of x_{t-1} given x_t and the predicted noise.
            mean = (1 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / sqrt_one_minus_alpha_cumprod_t) * predicted_noise)

            if t > 0:
                noise = torch.randn_like(x)
                variance = self.posterior_variance[t]
                x = mean + torch.sqrt(variance) * noise
            else:
                # Final step is deterministic.
                x = mean

        return x
|
|
|
|
class HierarchicalDiffusionModel(nn.Module):
    """Two-stage diffusion: a coarse model on a downsampled series plus a fine model on the residual.

    Inputs are ``(batch, length, in_channels)`` tensors. A strided convolution
    produces the coarse series; a transposed convolution maps it back to full
    resolution. The fine model is trained on windows of the residual
    ``x_0 - upsampled_coarse`` and is conditioned on the matching window of the
    upsampled coarse series.

    Args:
        in_channels: Number of feature channels.
        num_houses: Forwarded to both ``ConditionalUnet`` instances.
        downscale_factor: Kernel size and stride of the down/up-sampling convs.
        channel_weights: Optional per-channel loss weights for both stages.
        **model_kwargs: Must contain ``num_timesteps``; every remaining entry
            is forwarded to ``ConditionalUnet``.

    Raises:
        KeyError: If ``num_timesteps`` is missing from ``model_kwargs``.
    """

    def __init__(self, in_channels: int, num_houses: int, downscale_factor: int, channel_weights: Optional[torch.Tensor] = None, **model_kwargs):
        super().__init__()
        self.downscale_factor = downscale_factor
        # Maximum fine-stage window length: 192 steps.
        # Presumably two days of 15-minute readings - TODO confirm.
        self.fine_chunk_size = 2 * 96

        num_timesteps = model_kwargs.pop("num_timesteps")

        self.downsampler = nn.Conv1d(in_channels, in_channels, kernel_size=downscale_factor, stride=downscale_factor)
        self.upsampler = nn.ConvTranspose1d(in_channels, in_channels, kernel_size=downscale_factor, stride=downscale_factor)

        self.coarse_model = ImprovedDiffusionModel(
            ConditionalUnet(in_channels=in_channels, num_houses=num_houses, **model_kwargs),
            num_timesteps,
            channel_weights=channel_weights
        )
        # The fine model additionally receives the upsampled coarse series as input channels.
        self.fine_model = ImprovedDiffusionModel(
            ConditionalUnet(in_channels=in_channels, num_houses=num_houses,
                            cond_channels=in_channels, **model_kwargs),
            num_timesteps,
            channel_weights=channel_weights
        )

    @staticmethod
    def _match_length(x: torch.Tensor, target_length: int) -> torch.Tensor:
        """Zero-pad the end of, or crop, a ``(batch, length, channels)`` tensor to ``target_length``."""
        diff = target_length - x.shape[1]
        if diff > 0:
            # F.pad pairs run from the last axis backwards: (channels, length).
            return F.pad(x, [0, 0, 0, diff])
        if diff < 0:
            return x[:, :target_length, :]
        return x

    def forward(self, x_0: torch.Tensor, conditions: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return ``coarse_loss + 1.5 * fine_loss`` for a batch of clean series ``x_0``."""
        x_0_coarse = self.downsampler(x_0.permute(0, 2, 1)).permute(0, 2, 1)
        coarse_loss = self.coarse_model(x_0_coarse, conditions)

        # NOTE(review): built under no_grad from a detached input, so this
        # forward pass sends no gradient to ``self.upsampler``.
        with torch.no_grad():
            x_0_coarse_upsampled = self.upsampler(x_0_coarse.detach().permute(0, 2, 1)).permute(0, 2, 1)
            x_0_coarse_upsampled = self._match_length(x_0_coarse_upsampled, x_0.shape[1])
            x_0_fine_residual = x_0 - x_0_coarse_upsampled

        full_length = x_0.shape[1]
        # Use a local window size: the original overwrote self.fine_chunk_size
        # for short inputs, permanently shrinking it for every later call.
        chunk_size = min(self.fine_chunk_size, full_length)
        if full_length > chunk_size:
            start_index = torch.randint(0, full_length - chunk_size + 1, (1,)).item()
        else:
            start_index = 0

        window = slice(start_index, start_index + chunk_size)
        residual_chunk = x_0_fine_residual[:, window, :]
        conditioning_chunk = x_0_coarse_upsampled[:, window, :]

        fine_loss = self.fine_model(residual_chunk, conditions, conditioning_signal=conditioning_chunk)

        fine_loss_weight = 1.5
        return coarse_loss + (fine_loss * fine_loss_weight)

    @torch.no_grad()
    def sample(self, num_samples: int, conditions: Dict[str, torch.Tensor], shape: tuple) -> torch.Tensor:
        """Generate full-resolution samples.

        Args:
            num_samples: Number of series to generate.
            conditions: Conditioning tensors; moved to the model's device here.
            shape: ``(full_length, num_features)`` of one generated series.

        Returns:
            Tensor of shape ``(num_samples, full_length, num_features)``: the
            upsampled coarse sample plus the fine residual, which is generated
            in consecutive windows of at most ``fine_chunk_size`` steps.
        """
        full_length, num_features = shape
        device = next(self.parameters()).device

        conditions = {k: v.to(device) for k, v in conditions.items()}

        print("--- Stage 1: Sampling Coarse Structure ---")
        coarse_shape = (full_length // self.downscale_factor, num_features)
        generated_coarse = self.coarse_model.sample(num_samples, conditions, shape=coarse_shape)
        upsampled_coarse = self.upsampler(generated_coarse.permute(0, 2, 1)).permute(0, 2, 1)
        upsampled_coarse = self._match_length(upsampled_coarse, full_length)

        print("--- Stage 2: Sampling Fine Details ---")
        stitched_fine_residual = torch.zeros_like(upsampled_coarse)

        for start_index in tqdm(range(0, full_length, self.fine_chunk_size), desc="Fine chunks"):
            # The last window may be shorter than fine_chunk_size.
            end_index = min(start_index + self.fine_chunk_size, full_length)
            fine_shape = (end_index - start_index, num_features)
            conditioning_chunk = upsampled_coarse[:, start_index:end_index, :]

            generated_fine_chunk = self.fine_model.sample(
                num_samples, conditions, shape=fine_shape,
                conditioning_signal=conditioning_chunk
            )

            stitched_fine_residual[:, start_index:end_index, :] = generated_fine_chunk

        final_sample = upsampled_coarse + stitched_fine_residual
        return final_sample