"""model.py — Neural network architectures for NSGF/NSGF++.

Contains:
- VelocityMLP: MLP for 2D velocity field matching
- VelocityUNet: UNet for image velocity field matching (NSGF + NSF)
- PhaseTransitionPredictor: CNN for predicting transition time t_ϕ(x)

Reference: arXiv:2401.14069, Appendix E.1, E.2
"""
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import List, Optional |
|
|
|
|
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a time/position value.

    The first ``dim // 2`` output features are sines and the next
    ``dim // 2`` are cosines, evaluated at ``dim // 2`` geometrically spaced
    frequencies running from 1 down to 1/10000.  When ``dim`` is odd a single
    zero feature is appended so the output width is always exactly ``dim``.

    Args:
        dim: Width of the returned embedding.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t``.

        Args:
            t: Tensor of values to embed; it is cast to float32.

        Returns:
            Float32 tensor with one trailing axis of size ``dim`` added to
            ``t`` (a 0-dim ``t`` produces shape ``(1, dim)``).
        """
        half_dim = self.dim // 2
        # With a single frequency (dim 2 or 3) ``half_dim - 1`` is zero, which
        # used to raise ZeroDivisionError; the lone frequency is 1 regardless
        # of the step, so clamping the denominator is safe.
        log_step = math.log(10000.0) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, device=t.device, dtype=torch.float32) * -log_step
        )
        angles = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover ``dim - 1`` features for odd ``dim``;
            # zero-pad so consumers sized for ``dim`` inputs keep working.
            emb = F.pad(emb, (0, 1))
        return emb
|
|
|
|
class VelocityMLP(nn.Module):
    """MLP velocity field for 2D experiments.

    Paper: 3 hidden layers, 256 hidden units, SiLU activation.

    The network consumes the concatenation of the state ``x`` and a
    sinusoidal embedding of the time ``t`` and returns a tensor with the same
    trailing width as ``x``.

    Args:
        input_dim: Size of the last axis of ``x`` (and of the output).
        hidden_dim: Width of every hidden layer.
        num_hidden_layers: Number of hidden ``Linear`` + ``SiLU`` pairs; the
            first pair is always created, so values below 1 still yield one.
        time_emb_dim: Width of the sinusoidal time embedding.
    """

    def __init__(self, input_dim: int = 2, hidden_dim: int = 256,
                 num_hidden_layers: int = 3, time_emb_dim: int = 64):
        super().__init__()
        self.time_emb = SinusoidalPosEmb(time_emb_dim)
        layers = [nn.Linear(input_dim + time_emb_dim, hidden_dim), nn.SiLU()]
        for _ in range(num_hidden_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.SiLU())
        layers.append(nn.Linear(hidden_dim, input_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Evaluate the velocity field at ``(x, t)``.

        Args:
            x: Tensor whose last axis has size ``input_dim``.
            t: Times matching ``x.shape[:-1]``.  A Python number, a 0-dim
                tensor, a tensor with a trailing singleton axis (``(B, 1)``)
                or any shape broadcastable to ``x.shape[:-1]`` is accepted
                and expanded to the batch shape.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, dtype=torch.float32, device=x.device)
        batch_shape = x.shape[:-1]
        # ``(B, 1)``-shaped times would otherwise produce a 3-D embedding
        # that cannot be concatenated with a 2-D ``x``.
        if t.dim() == x.dim() and t.shape[-1] == 1:
            t = t.squeeze(-1)
        if t.shape != batch_shape:
            # Broadcast scalars / singleton batches across the whole batch.
            t = t.expand(batch_shape)
        t_emb = self.time_emb(t)
        return self.net(torch.cat([x, t_emb], dim=-1))
|
|
|
|
class ResBlock(nn.Module):
    """Residual block with AdaGN timestep conditioning.

    Pipeline: GroupNorm -> SiLU -> 3x3 conv, then the projected embedding is
    injected, followed by SiLU -> dropout -> 3x3 conv and a residual add.
    The final convolution is zero-initialised, so a freshly constructed block
    returns ``skip(x)``.

    Args:
        channels: Number of input channels.
        emb_dim: Width of the conditioning embedding.
        out_channels: Number of output channels; falls back to ``channels``
            when ``None`` (or 0).
        dropout: Dropout probability applied before the second convolution.
        use_scale_shift_norm: If true, the embedding is projected to a
            ``(scale, shift)`` pair applied as ``norm(h) * (1 + scale) + shift``;
            otherwise it is added to ``h`` before the second normalisation.
    """

    def __init__(self, channels: int, emb_dim: int, out_channels: Optional[int] = None,
                 dropout: float = 0.0, use_scale_shift_norm: bool = True):
        super().__init__()
        self.out_channels = out_channels or channels
        self.use_scale_shift_norm = use_scale_shift_norm
        # GroupNorm requires the channel count to be divisible by the group
        # count.  gcd(32, c) is 32 whenever c is a multiple of 32 (unchanged
        # behaviour) and otherwise the largest power of two dividing c, so
        # narrow layers no longer fail at construction time.
        self.norm1 = nn.GroupNorm(math.gcd(32, channels), channels)
        self.conv1 = nn.Conv2d(channels, self.out_channels, 3, padding=1)
        proj_width = 2 * self.out_channels if use_scale_shift_norm else self.out_channels
        self.time_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, proj_width),
        )
        self.norm2 = nn.GroupNorm(math.gcd(32, self.out_channels), self.out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)
        if channels != self.out_channels:
            # 1x1 conv so the residual matches the new channel count.
            self.skip = nn.Conv2d(channels, self.out_channels, 1)
        else:
            self.skip = nn.Identity()
        # Zero-init the last conv so the block starts out as its skip path.
        nn.init.zeros_(self.conv2.weight)
        nn.init.zeros_(self.conv2.bias)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Args:
            x: Feature map of shape ``(B, channels, H, W)``.
            emb: Conditioning embedding of shape ``(B, emb_dim)``.

        Returns:
            Feature map of shape ``(B, out_channels, H, W)``.
        """
        h = self.conv1(F.silu(self.norm1(x)))
        # Add two singleton axes so the projection broadcasts over H and W.
        emb_out = self.time_proj(emb)[:, :, None, None]
        if self.use_scale_shift_norm:
            scale, shift = emb_out.chunk(2, dim=1)
            h = self.norm2(h) * (1 + scale) + shift
        else:
            h = self.norm2(h + emb_out)
        h = self.conv2(self.dropout(F.silu(h)))
        return h + self.skip(x)
|
|
|
|
class AttentionBlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The output projection is zero-initialised, so a freshly constructed block
    is the identity (``x + 0``).

    Args:
        channels: Number of input/output channels.
        num_heads: Number of heads, used when ``num_head_channels <= 0``.
        num_head_channels: If positive, the head count is derived as
            ``channels // num_head_channels`` (at least one head).

    Raises:
        ValueError: If the resulting head count is not positive or does not
            evenly divide ``channels``.
    """

    def __init__(self, channels: int, num_heads: int = 1, num_head_channels: int = -1):
        super().__init__()
        if num_head_channels > 0:
            # A layer narrower than one head used to yield zero heads and a
            # ZeroDivisionError in forward(); fall back to a single head.
            heads = max(channels // num_head_channels, 1)
        else:
            heads = num_heads
        if heads < 1 or channels % heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by a positive number of heads, got {heads}"
            )
        self.num_heads = heads
        # gcd keeps the usual 32 groups for multiples of 32 and picks a valid
        # smaller group count otherwise (GroupNorm needs groups | channels).
        self.norm = nn.GroupNorm(math.gcd(32, channels), channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.proj = nn.Conv1d(channels, channels, 1)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply residual self-attention.

        Args:
            x: Feature map of shape ``(B, C, H, W)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        B, C, H, W = x.shape
        num_tokens = H * W
        head_dim = C // self.num_heads
        # reshape (not view): normalised activations may be non-contiguous,
        # e.g. when the input uses the channels_last memory format.
        h = self.norm(x).reshape(B, C, num_tokens)
        qkv = self.qkv(h).reshape(B, 3, self.num_heads, head_dim, num_tokens)
        q, k, v = qkv.unbind(dim=1)
        # Scaled dot-product attention; n indexes queries, m indexes keys.
        attn = torch.einsum("bhcn,bhcm->bhnm", q, k) * head_dim ** -0.5
        attn = attn.softmax(dim=-1)
        out = torch.einsum("bhnm,bhcm->bhcn", attn, v)
        out = self.proj(out.reshape(B, C, num_tokens)).view(B, C, H, W)
        return x + out
|
|
|
|
class Downsample(nn.Module):
    """Learned spatial downsampling: a 3x3 convolution with stride 2.

    The channel count is preserved; with padding 1 each spatial dimension
    ``n`` becomes ``ceil(n / 2)``.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the strided convolution of ``x``."""
        reduced = self.conv(x)
        return reduced
|
|
|
|
class Upsample(nn.Module):
    """Spatial upsampling: nearest-neighbour x2 followed by a 3x3 convolution.

    The channel count is preserved and each spatial dimension is doubled.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` enlarged by a factor of two and then convolved."""
        return self.conv(F.interpolate(x, scale_factor=2, mode="nearest"))
|
|
|
|
class VelocityUNet(nn.Module):
    """UNet velocity field for image experiments (Dhariwal & Nichol 2021).

    MNIST: channels=32, depth=1, ch_mult=[1,2,2], heads=1
    CIFAR-10: channels=128, depth=2, ch_mult=[1,2,2,2], heads=4, head_ch=64

    Args:
        image_size: Nominal input resolution.  Only used to track the
            resolution of each level when deciding where to place attention.
        in_channels: Channels of the input image (and of the output).
        model_channels: Base channel width; level ``i`` uses
            ``model_channels * channel_mult[i]`` channels.
        num_res_blocks: Residual blocks per level on the down path (the up
            path uses one more per level).
        channel_mult: Per-level channel multipliers; ``None`` means
            ``[1, 2, 2, 2]``.
        attention_resolutions: Feature-map resolutions at which attention is
            inserted after each residual block; ``None`` means ``[16]``.
        num_heads: Attention heads (used when ``num_head_channels <= 0``).
        num_head_channels: Channels per attention head, if positive.
        dropout: Dropout probability inside every residual block.
        use_scale_shift_norm: Forwarded to every residual block.

    NOTE(review): skip connections are concatenated without resizing, so the
    input height/width presumably must be divisible by
    ``2 ** (len(channel_mult) - 1)`` — confirm against the datasets used.
    """

    def __init__(self, image_size: int = 32, in_channels: int = 3,
                 model_channels: int = 128, num_res_blocks: int = 2,
                 channel_mult: Optional[List[int]] = None,
                 attention_resolutions: Optional[List[int]] = None,
                 num_heads: int = 4, num_head_channels: int = 64,
                 dropout: float = 0.0, use_scale_shift_norm: bool = True):
        super().__init__()
        # ``None`` stands in for the documented defaults so that no list
        # object is shared between instances (mutable-default pitfall).
        channel_mult = [1, 2, 2, 2] if channel_mult is None else list(channel_mult)
        attention_resolutions = (
            [16] if attention_resolutions is None else list(attention_resolutions)
        )

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels

        time_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(model_channels),
            nn.Linear(model_channels, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )
        self.input_conv = nn.Conv2d(in_channels, model_channels, 3, padding=1)

        def attention_at(resolution: int, width: int) -> nn.Module:
            """Attention block at the configured resolutions, identity elsewhere."""
            if resolution in attention_resolutions:
                return AttentionBlock(width, num_heads, num_head_channels)
            return nn.Identity()

        # ---- Down path -------------------------------------------------
        self.down_blocks = nn.ModuleList()
        self.down_attns = nn.ModuleList()
        self.downsamplers = nn.ModuleList()

        last_level = len(channel_mult) - 1
        ch = model_channels
        ds = image_size  # current feature-map resolution
        # Channel count of every tensor forward() pushes onto its skip stack.
        skip_channels = [ch]

        for level, mult in enumerate(channel_mult):
            out_ch = model_channels * mult
            for _ in range(num_res_blocks):
                self.down_blocks.append(
                    ResBlock(ch, time_dim, out_ch, dropout, use_scale_shift_norm)
                )
                self.down_attns.append(attention_at(ds, out_ch))
                ch = out_ch
                skip_channels.append(ch)
            if level < last_level:
                self.downsamplers.append(Downsample(ch))
                ds //= 2
                skip_channels.append(ch)
            else:
                # Placeholder keeps one entry per level; forward() skips it.
                self.downsamplers.append(nn.Identity())

        # ---- Bottleneck ------------------------------------------------
        self.mid_block1 = ResBlock(ch, time_dim, ch, dropout, use_scale_shift_norm)
        self.mid_attn = AttentionBlock(ch, num_heads, num_head_channels)
        self.mid_block2 = ResBlock(ch, time_dim, ch, dropout, use_scale_shift_norm)

        # ---- Up path ---------------------------------------------------
        self.up_blocks = nn.ModuleList()
        self.up_attns = nn.ModuleList()
        self.upsamplers = nn.ModuleList()

        for level in reversed(range(len(channel_mult))):
            out_ch = model_channels * channel_mult[level]
            for _ in range(num_res_blocks + 1):
                # Each up block consumes one skip tensor via concatenation.
                skip_ch = skip_channels.pop()
                self.up_blocks.append(
                    ResBlock(ch + skip_ch, time_dim, out_ch, dropout, use_scale_shift_norm)
                )
                self.up_attns.append(attention_at(ds, out_ch))
                ch = out_ch
            if level > 0:
                self.upsamplers.append(Upsample(ch))
                ds *= 2
            else:
                self.upsamplers.append(nn.Identity())

        # gcd keeps 32 groups for widths that are multiples of 32 and picks
        # a valid smaller group count otherwise.
        self.out_norm = nn.GroupNorm(math.gcd(32, ch), ch)
        self.out_conv = nn.Conv2d(ch, in_channels, 3, padding=1)
        # Zero-init so the untrained network predicts a zero velocity field.
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict the velocity field at ``(x, t)``.

        Args:
            x: Image batch of shape ``(B, in_channels, H, W)``.
            t: Times, one per sample or a single shared value.  A Python
                number or a tensor of any shape (scalar, ``(B,)``,
                ``(B, 1)``, ...) is accepted and flattened to one dimension.

        Returns:
            Tensor of the same shape as ``x``.
        """
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, dtype=torch.float32, device=x.device)
        # The time MLP and residual blocks expect a 2-D (batch, features)
        # embedding, which requires a 1-D time vector.
        t = t.reshape(-1)
        emb = self.time_embed(t * 1000.0)

        h = self.input_conv(x)
        skips = [h]

        blocks_per_level = self._get_num_res_blocks()
        for level, downsample in enumerate(self.downsamplers):
            first = level * blocks_per_level
            for idx in range(first, first + blocks_per_level):
                h = self.down_blocks[idx](h, emb)
                h = self.down_attns[idx](h)
                skips.append(h)
            if not isinstance(downsample, nn.Identity):
                h = downsample(h)
                skips.append(h)

        h = self.mid_block1(h, emb)
        h = self.mid_attn(h)
        h = self.mid_block2(h, emb)

        up_per_level = blocks_per_level + 1
        for level, upsample in enumerate(self.upsamplers):
            first = level * up_per_level
            for idx in range(first, first + up_per_level):
                h = torch.cat([h, skips.pop()], dim=1)
                h = self.up_blocks[idx](h, emb)
                h = self.up_attns[idx](h)
            if not isinstance(upsample, nn.Identity):
                h = upsample(h)

        h = F.silu(self.out_norm(h))
        return self.out_conv(h)

    def _get_num_res_blocks(self) -> int:
        """Return the number of down-path residual blocks per level.

        Derived from the registered module lists; returns 0 when the network
        has no levels instead of dividing by zero.
        """
        num_levels = len(self.downsamplers)
        if num_levels == 0:
            return 0
        return len(self.down_blocks) // num_levels
|
|
|
|
class PhaseTransitionPredictor(nn.Module):
    """CNN predicting phase-transition time t_ϕ(x) ∈ [0, 1].

    4 conv layers (32→64→128→256), 3x3, AvgPool2d, FC + Sigmoid.

    Each stage is ``Conv2d(3x3, padding 1) -> ReLU -> AvgPool2d(2)``, so every
    stage halves (floor) the spatial size.  The flattened features feed a
    single linear unit followed by a sigmoid.

    Args:
        in_channels: Channels of the input image.
        image_size: Height/width of the (square) input; determines the width
            of the fully connected layer.
        conv_channels: Output channels of each conv stage; ``None`` means
            ``[32, 64, 128, 256]``.

    Raises:
        ValueError: If ``image_size`` is too small to survive every pooling
            stage (the feature map would be empty).
    """

    def __init__(self, in_channels: int = 1, image_size: int = 28,
                 conv_channels: Optional[List[int]] = None):
        super().__init__()
        # ``None`` stands in for the default so no list is shared across
        # instances (mutable-default pitfall).
        conv_channels = [32, 64, 128, 256] if conv_channels is None else list(conv_channels)

        layers = []
        ch = in_channels
        for out_ch in conv_channels:
            layers.extend([
                nn.Conv2d(ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.AvgPool2d(kernel_size=2, stride=2),
            ])
            ch = out_ch
        self.conv = nn.Sequential(*layers)

        # One floor-halving per pooling stage.
        final_size = image_size // (2 ** len(conv_channels))
        if final_size < 1:
            raise ValueError(
                f"image_size={image_size} is too small for {len(conv_channels)} pooling stages"
            )
        # ``ch`` is the last stage's width (or in_channels with no stages).
        self.fc_input_dim = ch * final_size * final_size
        self.fc = nn.Linear(self.fc_input_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict one transition time per sample.

        Args:
            x: Image batch of shape ``(B, in_channels, image_size, image_size)``.

        Returns:
            Tensor of shape ``(B,)`` with values produced by a sigmoid.
        """
        # flatten() copies when needed, unlike view(), which fails on
        # non-contiguous (e.g. channels_last) feature maps.
        features = self.conv(x).flatten(start_dim=1)
        return self.sigmoid(self.fc(features)).squeeze(-1)
|
|
|
|
| |
def create_velocity_model_2d(config: dict) -> VelocityMLP:
    """Build a ``VelocityMLP`` from a configuration dictionary.

    Hyper-parameters are read from the optional ``"model"`` sub-dictionary;
    any key that is absent falls back to the default listed below.

    Args:
        config: Configuration mapping, optionally containing a ``"model"``
            entry.

    Returns:
        The constructed ``VelocityMLP``.
    """
    fallbacks = {
        "input_dim": 2,
        "hidden_dim": 256,
        "num_hidden_layers": 3,
        "time_emb_dim": 64,
    }
    section = config.get("model", {})
    kwargs = {key: section.get(key, default) for key, default in fallbacks.items()}
    return VelocityMLP(**kwargs)
|
|
def create_velocity_unet(config: dict) -> VelocityUNet:
    """Build a ``VelocityUNet`` from a configuration dictionary.

    ``image_size`` and ``in_channels`` are read from the top level of
    ``config``; all architecture settings come from the optional ``"unet"``
    sub-dictionary.  Missing keys fall back to the defaults listed below.

    Args:
        config: Configuration mapping.

    Returns:
        The constructed ``VelocityUNet``.
    """
    fallbacks = {
        "model_channels": 128,
        "num_res_blocks": 2,
        "channel_mult": [1, 2, 2, 2],
        "attention_resolutions": [16],
        "num_heads": 4,
        "num_head_channels": 64,
        "dropout": 0.0,
        "use_scale_shift_norm": True,
    }
    section = config.get("unet", {})
    arch_kwargs = {key: section.get(key, default) for key, default in fallbacks.items()}
    return VelocityUNet(
        image_size=config.get("image_size", 32),
        in_channels=config.get("in_channels", 3),
        **arch_kwargs,
    )
|
|
def create_phase_predictor(config: dict) -> PhaseTransitionPredictor:
    """Build a ``PhaseTransitionPredictor`` from a configuration dictionary.

    ``in_channels`` and ``image_size`` are read from the top level of
    ``config``; ``conv_channels`` comes from the optional
    ``"time_predictor"`` sub-dictionary.

    Args:
        config: Configuration mapping.

    Returns:
        The constructed ``PhaseTransitionPredictor``.
    """
    section = config.get("time_predictor", {})
    kwargs = {
        "in_channels": config.get("in_channels", 1),
        "image_size": config.get("image_size", 28),
        "conv_channels": section.get("conv_channels", [32, 64, 128, 256]),
    }
    return PhaseTransitionPredictor(**kwargs)
|
|