# nsgf-plusplus / model.py
# Uploaded by rogermt ("Upload model.py", commit abe114d, verified)
"""model.py — Neural network architectures for NSGF/NSGF++.
Contains:
- VelocityMLP: MLP for 2D velocity field matching
- VelocityUNet: UNet for image velocity field matching (NSGF + NSF)
- PhaseTransitionPredictor: CNN for predicting transition time t_ϕ(x)
Reference: arXiv:2401.14069, Appendix E.1, E.2
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a time/position value.

    With ``k = dim // 2``, every input value ``t`` is mapped to
    ``[sin(t * f_0), ..., sin(t * f_{k-1}), cos(t * f_0), ..., cos(t * f_{k-1})]``
    using the geometrically spaced frequencies ``f_i = 10000 ** (-i / (k - 1))``.
    When ``dim`` is odd a single zero column is appended so that the output
    width always equals ``dim``.

    Args:
        dim: Width of the returned embedding.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t``; the result has shape ``(*t.shape, dim)`` (``(1, dim)`` for a 0-d ``t``)."""
        half_dim = self.dim // 2
        # For dim in {2, 3} there is a single frequency and ``half_dim - 1`` is 0;
        # clamp the divisor so the lone frequency is 1.0 instead of dividing by zero.
        log_step = math.log(10000.0) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, device=t.device, dtype=torch.float32) * -log_step
        )
        angles = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover 2 * (dim // 2) columns; pad the last one
            # so downstream layers sized with ``dim`` receive the width they expect.
            emb = F.pad(emb, (0, 1))
        return emb
class VelocityMLP(nn.Module):
    """MLP velocity field for 2D experiments.

    Paper: 3 hidden layers, 256 hidden units, SiLU activation.

    The network input is the point ``x`` concatenated with a sinusoidal
    embedding of the time ``t``; the output has the same width as ``x``.

    Args:
        input_dim: Width of ``x`` (and of the predicted velocity).
        hidden_dim: Width of every hidden layer.
        num_hidden_layers: Number of hidden ``Linear`` + ``SiLU`` layers. The
            first hidden layer is always created, so values below 1 still
            produce one hidden layer.
        time_emb_dim: Width of the sinusoidal time embedding.
    """

    def __init__(self, input_dim: int = 2, hidden_dim: int = 256,
                 num_hidden_layers: int = 3, time_emb_dim: int = 64):
        super().__init__()
        self.time_emb = SinusoidalPosEmb(time_emb_dim)
        layers = [nn.Linear(input_dim + time_emb_dim, hidden_dim), nn.SiLU()]
        for _ in range(num_hidden_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.SiLU()]
        layers.append(nn.Linear(hidden_dim, input_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Return the velocity at ``x`` for time ``t``.

        ``t`` may have the batch shape of ``x`` (``x.shape[:-1]``), carry an
        extra trailing singleton axis (e.g. ``(B, 1)``), or be any shape that
        broadcasts to the batch shape (e.g. a 0-d tensor shared by the batch).
        """
        batch_shape = x.shape[:-1]
        # A column vector such as (B, 1) would otherwise yield a 3-D embedding
        # that cannot be concatenated with the 2-D ``x``.
        if t.dim() == x.dim() and t.shape[-1] == 1:
            t = t.squeeze(-1)
        # Broadcast a shared time (0-d or size-1) across the batch.
        if t.shape != batch_shape:
            t = t.expand(batch_shape)
        t_emb = self.time_emb(t)
        xt = torch.cat([x, t_emb], dim=-1)
        return self.net(xt)
class ResBlock(nn.Module):
    """Residual block with AdaGN timestep conditioning.

    Layout: ``GroupNorm -> SiLU -> Conv3x3 -> (conditioned) GroupNorm -> SiLU
    -> Dropout -> Conv3x3`` plus a skip connection. The final convolution is
    zero-initialised, so a freshly built block returns ``skip(x)``.

    Args:
        channels: Number of input channels.
        emb_dim: Width of the conditioning embedding passed to ``forward``.
        out_channels: Number of output channels; falls back to ``channels``
            when ``None`` (or 0).
        dropout: Dropout probability applied before the second convolution.
        use_scale_shift_norm: If true the embedding is projected to a
            per-channel ``(scale, shift)`` pair applied after the second
            normalisation; otherwise it is projected to a per-channel bias
            added before that normalisation.
        num_groups: Number of groups for both ``GroupNorm`` layers. It must
            divide ``channels`` and the output channel count.
    """

    def __init__(self, channels: int, emb_dim: int, out_channels: Optional[int] = None,
                 dropout: float = 0.0, use_scale_shift_norm: bool = True,
                 num_groups: int = 32):
        super().__init__()
        self.out_channels = out_channels or channels
        self.use_scale_shift_norm = use_scale_shift_norm

        self.norm1 = nn.GroupNorm(num_groups, channels)
        self.conv1 = nn.Conv2d(channels, self.out_channels, 3, padding=1)

        # Scale-shift conditioning needs two values per output channel.
        proj_width = 2 * self.out_channels if use_scale_shift_norm else self.out_channels
        self.time_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, proj_width),
        )

        self.norm2 = nn.GroupNorm(num_groups, self.out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)

        # A 1x1 convolution matches channel counts on the skip path when needed.
        if channels != self.out_channels:
            self.skip = nn.Conv2d(channels, self.out_channels, 1)
        else:
            self.skip = nn.Identity()

        # Zero-init the last conv so the residual branch starts as a no-op.
        nn.init.zeros_(self.conv2.weight)
        nn.init.zeros_(self.conv2.bias)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Apply the block to feature map ``x`` conditioned on embedding ``emb``."""
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)

        # Add two trailing singleton axes so the projection broadcasts over H and W.
        emb_out = self.time_proj(emb)[:, :, None, None]
        if self.use_scale_shift_norm:
            scale, shift = emb_out.chunk(2, dim=1)
            h = self.norm2(h) * (1 + scale) + shift
        else:
            h = self.norm2(h + emb_out)

        h = F.silu(h)
        h = self.dropout(h)
        h = self.conv2(h)
        return h + self.skip(x)
class AttentionBlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The output projection is zero-initialised, so a freshly built block is
    the identity mapping.

    Args:
        channels: Number of input (and output) channels.
        num_heads: Number of attention heads, used when ``num_head_channels``
            is not positive.
        num_head_channels: If positive, the head count is derived as
            ``channels // num_head_channels`` (at least 1) and ``num_heads``
            is ignored.

    Raises:
        ValueError: If the resulting head count is not a positive divisor of
            ``channels``.
    """

    def __init__(self, channels: int, num_heads: int = 1, num_head_channels: int = -1):
        super().__init__()
        if num_head_channels > 0:
            # Narrow layers (channels < num_head_channels) would otherwise get
            # zero heads and crash with a division by zero in forward().
            heads = max(1, channels // num_head_channels)
        else:
            heads = num_heads
        if heads < 1 or channels % heads != 0:
            # A non-dividing head count can make the qkv reshape succeed with a
            # wrong layout, silently corrupting the attention computation.
            raise ValueError(
                f"channels ({channels}) must be divisible by a positive number of heads, got {heads}"
            )
        self.num_heads = heads

        self.norm = nn.GroupNorm(32, channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.proj = nn.Conv1d(channels, channels, 1)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the attention output; expects a ``(B, C, H, W)`` tensor."""
        B, C, H, W = x.shape
        head_dim = C // self.num_heads

        # Flatten H x W into one sequence axis of length N.
        h = self.norm(x).view(B, C, -1)
        qkv = self.qkv(h).reshape(B, 3, self.num_heads, head_dim, -1)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]

        # Scaled dot-product attention; softmax runs over the key positions (m).
        scale = head_dim ** -0.5
        attn = torch.einsum("bhcn,bhcm->bhnm", q, k) * scale
        attn = attn.softmax(dim=-1)
        out = torch.einsum("bhnm,bhcm->bhcn", attn, v)

        out = out.reshape(B, C, -1)
        out = self.proj(out).view(B, C, H, W)
        return x + out
class Downsample(nn.Module):
    """Spatial downsampling via a learned 3x3 convolution with stride 2.

    Args:
        channels: Channel count, identical for input and output.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the strided convolution on ``x``."""
        reduced = self.conv(x)
        return reduced
class Upsample(nn.Module):
    """Spatial 2x upsampling: nearest-neighbour resize, then a 3x3 convolution.

    Args:
        channels: Channel count, identical for input and output.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Double the spatial size of ``x`` and smooth it with the convolution."""
        enlarged = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(enlarged)
class VelocityUNet(nn.Module):
    """UNet velocity field for image experiments (Dhariwal & Nichol 2021).

    MNIST: channels=32, depth=1, ch_mult=[1,2,2], heads=1
    CIFAR-10: channels=128, depth=2, ch_mult=[1,2,2,2], heads=4, head_ch=64

    Args:
        image_size: Side length of the input images; used to track the
            feature-map resolution when placing attention blocks.
        in_channels: Channels of the input image (and of the output velocity).
        model_channels: Base channel width; also the width of the sinusoidal
            time embedding.
        num_res_blocks: Residual blocks per resolution level in the encoder
            (the decoder uses one more per level).
        channel_mult: Per-level multipliers of ``model_channels``. Only read,
            never modified.
        attention_resolutions: Feature-map side lengths at which an attention
            block follows every residual block. Only read, never modified.
        num_heads: Attention heads, used when ``num_head_channels`` is not positive.
        num_head_channels: Channels per attention head; if positive it
            determines the head count instead of ``num_heads``.
        dropout: Dropout probability inside the residual blocks.
        use_scale_shift_norm: Conditioning style of the residual blocks.
    """

    def __init__(self, image_size: int = 32, in_channels: int = 3,
                 model_channels: int = 128, num_res_blocks: int = 2,
                 channel_mult: List[int] = [1, 2, 2, 2],
                 attention_resolutions: List[int] = [16],
                 num_heads: int = 4, num_head_channels: int = 64,
                 dropout: float = 0.0, use_scale_shift_norm: bool = True):
        super().__init__()
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels

        # Time conditioning: sinusoidal features followed by a two-layer MLP.
        time_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(model_channels),
            nn.Linear(model_channels, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )

        self.input_conv = nn.Conv2d(in_channels, model_channels, 3, padding=1)

        # ---- Encoder -------------------------------------------------------
        # down_blocks / down_attns are flat lists (num_res_blocks entries per
        # level); downsamplers holds one entry per level (Identity on the last).
        self.down_blocks = nn.ModuleList()
        self.down_attns = nn.ModuleList()
        self.downsamplers = nn.ModuleList()

        num_levels = len(channel_mult)
        ch = model_channels
        ds = image_size  # current feature-map side length
        # Channel count of every tensor forward() pushes on the skip stack, in push order.
        input_block_channels = [ch]
        for level, mult in enumerate(channel_mult):
            out_ch = model_channels * mult
            for _ in range(num_res_blocks):
                self.down_blocks.append(
                    ResBlock(ch, time_dim, out_ch, dropout, use_scale_shift_norm)
                )
                if ds in attention_resolutions:
                    self.down_attns.append(AttentionBlock(out_ch, num_heads, num_head_channels))
                else:
                    self.down_attns.append(nn.Identity())
                ch = out_ch
                input_block_channels.append(ch)
            if level < num_levels - 1:
                self.downsamplers.append(Downsample(ch))
                ds //= 2
                input_block_channels.append(ch)
            else:
                self.downsamplers.append(nn.Identity())

        # ---- Bottleneck ----------------------------------------------------
        self.mid_block1 = ResBlock(ch, time_dim, ch, dropout, use_scale_shift_norm)
        self.mid_attn = AttentionBlock(ch, num_heads, num_head_channels)
        self.mid_block2 = ResBlock(ch, time_dim, ch, dropout, use_scale_shift_norm)

        # ---- Decoder -------------------------------------------------------
        # Mirrors the encoder, deepest level first. Each level has
        # num_res_blocks + 1 blocks, and every block consumes one skip tensor.
        self.up_blocks = nn.ModuleList()
        self.up_attns = nn.ModuleList()
        self.upsamplers = nn.ModuleList()
        for level in reversed(range(num_levels)):
            out_ch = model_channels * channel_mult[level]
            for _ in range(num_res_blocks + 1):
                skip_ch = input_block_channels.pop()
                self.up_blocks.append(
                    ResBlock(ch + skip_ch, time_dim, out_ch, dropout, use_scale_shift_norm)
                )
                if ds in attention_resolutions:
                    self.up_attns.append(AttentionBlock(out_ch, num_heads, num_head_channels))
                else:
                    self.up_attns.append(nn.Identity())
                ch = out_ch
            if level > 0:
                self.upsamplers.append(Upsample(ch))
                ds *= 2
            else:
                self.upsamplers.append(nn.Identity())

        # ---- Output head (zero-initialised: the initial prediction is 0) ----
        self.out_norm = nn.GroupNorm(32, ch)
        self.out_conv = nn.Conv2d(ch, in_channels, 3, padding=1)
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict the velocity for images ``x`` at time ``t``.

        ``t`` may hold one value per sample (shape ``(B,)``, ``(B, 1)``, ...)
        or a single value shared by the whole batch (0-d or one element).
        """
        # Flatten so that e.g. a (B, 1) column does not yield an embedding with
        # an extra axis, and broadcast a shared time across the batch.
        t = t.reshape(-1)
        if t.shape[0] == 1:
            t = t.expand(x.shape[0])
        emb = self.time_embed(t * 1000.0)

        h = self.input_conv(x)
        skips = [h]
        blocks_per_level = self._get_num_res_blocks()

        # Encoder: remember every intermediate activation for the decoder.
        block_idx = 0
        for downsampler in self.downsamplers:
            for _ in range(blocks_per_level):
                h = self.down_blocks[block_idx](h, emb)
                h = self.down_attns[block_idx](h)
                skips.append(h)
                block_idx += 1
            # The last level carries an Identity placeholder, which must not add a skip.
            if not isinstance(downsampler, nn.Identity):
                h = downsampler(h)
                skips.append(h)

        h = self.mid_block1(h, emb)
        h = self.mid_attn(h)
        h = self.mid_block2(h, emb)

        # Decoder: pop skips in reverse order and concatenate along channels.
        block_idx = 0
        for upsampler in self.upsamplers:
            for _ in range(blocks_per_level + 1):
                h = torch.cat([h, skips.pop()], dim=1)
                h = self.up_blocks[block_idx](h, emb)
                h = self.up_attns[block_idx](h)
                block_idx += 1
            if not isinstance(upsampler, nn.Identity):
                h = upsampler(h)

        h = self.out_norm(h)
        h = F.silu(h)
        return self.out_conv(h)

    def _get_num_res_blocks(self):
        """Return the encoder's residual blocks per level, derived from the module lists."""
        num_levels = len(self.downsamplers)
        if num_levels == 0:
            # No resolution levels were configured; avoid dividing by zero.
            return 0
        return len(self.down_blocks) // num_levels
class PhaseTransitionPredictor(nn.Module):
    """CNN predicting phase-transition time t_ϕ(x) ∈ [0, 1].

    4 conv layers (32→64→128→256), 3x3, AvgPool2d, FC + Sigmoid.

    Args:
        in_channels: Channels of the input image.
        image_size: Side length of the (square) input image; it fixes the
            input width of the final linear layer.
        conv_channels: Output channels of the successive conv stages. Each
            stage is ``Conv3x3 -> ReLU -> AvgPool2d(2)``. Only read, never
            modified.

    Raises:
        ValueError: If ``image_size`` is too small to survive one 2x pooling
            per conv stage.
    """

    def __init__(self, in_channels: int = 1, image_size: int = 28,
                 conv_channels: List[int] = [32, 64, 128, 256]):
        super().__init__()
        layers = []
        ch = in_channels
        final_size = image_size
        for out_ch in conv_channels:
            layers.extend([
                nn.Conv2d(ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.AvgPool2d(kernel_size=2, stride=2),
            ])
            ch = out_ch
            # Each pooling stage floors the side length in half.
            final_size //= 2
        if final_size < 1:
            # Such a model could never run: a pooling stage would receive a
            # feature map smaller than its kernel.
            raise ValueError(
                f"image_size={image_size} is too small for {len(conv_channels)} pooling stages"
            )
        self.conv = nn.Sequential(*layers)
        self.fc_input_dim = conv_channels[-1] * final_size * final_size
        self.fc = nn.Linear(self.fc_input_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one value in [0, 1] per sample (shape ``(B,)``) for images ``x``."""
        h = self.conv(x)
        # flatten() copies when needed, so it also works for non-contiguous
        # activations (e.g. channels_last inputs), where view() raises.
        h = h.flatten(1)
        h = self.fc(h)
        return self.sigmoid(h).squeeze(-1)
# Factory functions
def create_velocity_model_2d(config: dict) -> VelocityMLP:
    """Build a ``VelocityMLP`` from the ``"model"`` section of ``config``.

    Keys missing from the section (or a missing section) fall back to the
    defaults listed below.
    """
    fallbacks = {
        "input_dim": 2,
        "hidden_dim": 256,
        "num_hidden_layers": 3,
        "time_emb_dim": 64,
    }
    section = config.get("model", {})
    kwargs = {key: section.get(key, default) for key, default in fallbacks.items()}
    return VelocityMLP(**kwargs)
def create_velocity_unet(config: dict) -> VelocityUNet:
    """Build a ``VelocityUNet`` from ``config``.

    ``image_size`` and ``in_channels`` are read from the top level of
    ``config``; every other hyper-parameter comes from its ``"unet"``
    section. Missing entries fall back to the defaults listed below.
    """
    # The dict is rebuilt on every call, so the list defaults are never shared.
    unet_fallbacks = {
        "model_channels": 128,
        "num_res_blocks": 2,
        "channel_mult": [1, 2, 2, 2],
        "attention_resolutions": [16],
        "num_heads": 4,
        "num_head_channels": 64,
        "dropout": 0.0,
        "use_scale_shift_norm": True,
    }
    section = config.get("unet", {})
    kwargs = {key: section.get(key, default) for key, default in unet_fallbacks.items()}
    kwargs["image_size"] = config.get("image_size", 32)
    kwargs["in_channels"] = config.get("in_channels", 3)
    return VelocityUNet(**kwargs)
def create_phase_predictor(config: dict) -> PhaseTransitionPredictor:
    """Build a ``PhaseTransitionPredictor`` from ``config``.

    ``in_channels`` and ``image_size`` are read from the top level of
    ``config``; ``conv_channels`` comes from its ``"time_predictor"``
    section. Missing entries fall back to 1, 28 and ``[32, 64, 128, 256]``.
    """
    section = config.get("time_predictor", {})
    channels = config.get("in_channels", 1)
    side = config.get("image_size", 28)
    stages = section.get("conv_channels", [32, 64, 128, 256])
    return PhaseTransitionPredictor(
        in_channels=channels,
        image_size=side,
        conv_channels=stages,
    )