"""model.py — Neural network architectures for NSGF/NSGF++.

Contains:
    - VelocityMLP: MLP for 2D velocity field matching
    - VelocityUNet: UNet for image velocity field matching (NSGF + NSF)
    - PhaseTransitionPredictor: CNN for predicting transition time t_ϕ(x)

Reference: arXiv:2401.14069, Appendix E.1, E.2
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Sequence


def _group_norm(channels: int, max_groups: int = 32) -> nn.GroupNorm:
    """Build a GroupNorm over ``channels`` using at most ``max_groups`` groups.

    GroupNorm requires ``channels % num_groups == 0``. For widths that are
    multiples of 32 this is exactly ``nn.GroupNorm(32, channels)``; for other
    widths the largest common divisor is used instead of raising.
    """
    return nn.GroupNorm(math.gcd(max_groups, channels), channels)


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a (continuous) time value.

    ``forward(t)`` returns a float32 tensor of shape ``t.shape + (dim,)``
    (a 0-d ``t`` yields shape ``(1, dim)``): the first ``dim // 2`` features
    are ``sin(t * f_k)`` and the next ``dim // 2`` are ``cos(t * f_k)``, with
    frequencies ``f_k`` spaced geometrically from 1 down to 1/10000. For odd
    ``dim`` a trailing zero feature is appended so the width is always ``dim``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half_dim = self.dim // 2
        # max(..., 1) avoids a division by zero when dim is 2 or 3.
        log_step = math.log(10000.0) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, device=t.device, dtype=torch.float32) * -log_step
        )
        # unsqueeze(0) keeps a 0-d t producing a (1, dim) result.
        args = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        if self.dim % 2 == 1:
            # Pad so consumers sized for `dim` features get exactly `dim`.
            emb = F.pad(emb, (0, 1))
        return emb


class VelocityMLP(nn.Module):
    """MLP velocity field for 2D experiments.

    Paper: 3 hidden layers, 256 hidden units, SiLU activation.

    The input point ``x`` is concatenated with a sinusoidal embedding of ``t``
    and mapped back to ``input_dim`` outputs.
    """

    def __init__(self, input_dim: int = 2, hidden_dim: int = 256,
                 num_hidden_layers: int = 3, time_emb_dim: int = 64):
        super().__init__()
        self.time_emb = SinusoidalPosEmb(time_emb_dim)
        layers: List[nn.Module] = [nn.Linear(input_dim + time_emb_dim, hidden_dim), nn.SiLU()]
        for _ in range(num_hidden_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.SiLU()]
        layers.append(nn.Linear(hidden_dim, input_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Evaluate the velocity at points ``x`` and times ``t``.

        Args:
            x: points of shape ``(..., input_dim)``.
            t: times whose shape matches ``x.shape[:-1]``; a scalar (0-d) time
               or a ``(..., 1)`` column of times is also accepted and broadcast.

        Returns:
            Tensor of shape ``(..., input_dim)``.
        """
        if t.dim() == x.dim() and t.shape[-1] == 1:
            t = t.squeeze(-1)  # accept column-vector times, e.g. (B, 1)
        t_emb = self.time_emb(t).to(x.dtype)
        # Broadcast a shared time (0-d t -> (1, dim)) over every point in x.
        t_emb = t_emb.expand(*x.shape[:-1], t_emb.shape[-1])
        return self.net(torch.cat([x, t_emb], dim=-1))


class ResBlock(nn.Module):
    """Residual block with AdaGN timestep conditioning.

    ``norm1 -> SiLU -> conv1``, then the timestep embedding either modulates
    the second norm (``norm2(h) * (1 + scale) + shift``) or, when
    ``use_scale_shift_norm`` is False, is added before ``norm2``; then
    ``SiLU -> dropout -> conv2`` plus a skip connection (1x1 conv when the
    channel count changes). ``conv2`` is zero-initialised, so at init the block
    returns just ``skip(x)``.
    """

    def __init__(self, channels: int, emb_dim: int, out_channels: Optional[int] = None,
                 dropout: float = 0.0, use_scale_shift_norm: bool = True):
        super().__init__()
        self.out_channels = out_channels or channels
        self.use_scale_shift_norm = use_scale_shift_norm

        self.norm1 = _group_norm(channels)
        self.conv1 = nn.Conv2d(channels, self.out_channels, 3, padding=1)
        proj_dim = 2 * self.out_channels if use_scale_shift_norm else self.out_channels
        self.time_proj = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, proj_dim))
        self.norm2 = _group_norm(self.out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)
        if channels != self.out_channels:
            self.skip = nn.Conv2d(channels, self.out_channels, 1)
        else:
            self.skip = nn.Identity()

        nn.init.zeros_(self.conv2.weight)
        nn.init.zeros_(self.conv2.bias)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        h = self.conv1(F.silu(self.norm1(x)))
        # (B, D) -> (B, D, 1, 1) so it broadcasts over the spatial dims.
        emb_out = self.time_proj(emb)[:, :, None, None]
        if self.use_scale_shift_norm:
            scale, shift = emb_out.chunk(2, dim=1)
            h = self.norm2(h) * (1 + scale) + shift
        else:
            h = self.norm2(h + emb_out)
        h = self.conv2(self.dropout(F.silu(h)))
        return h + self.skip(x)


class AttentionBlock(nn.Module):
    """Residual multi-head self-attention over the spatial positions of a map.

    The head count is ``channels // num_head_channels`` (at least 1) when
    ``num_head_channels > 0``, otherwise ``num_heads``. The output projection
    is zero-initialised, so the block is the identity at init.
    """

    def __init__(self, channels: int, num_heads: int = 1, num_head_channels: int = -1):
        super().__init__()
        if num_head_channels > 0:
            # max(1, ...) avoids zero heads when channels < num_head_channels.
            num_heads = max(1, channels // num_head_channels)
        if num_heads < 1 or channels % num_heads != 0:
            raise ValueError(
                f"channels={channels} cannot be split into num_heads={num_heads} heads"
            )
        self.num_heads = num_heads
        self.norm = _group_norm(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.proj = nn.Conv1d(channels, channels, 1)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        head_dim = C // self.num_heads
        h = self.norm(x).view(B, C, -1)
        # (B, 3C, N) -> (B, 3, heads, head_dim, N)
        qkv = self.qkv(h).reshape(B, 3, self.num_heads, head_dim, -1)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
        attn = torch.einsum("bhcn,bhcm->bhnm", q, k) * head_dim ** -0.5
        attn = attn.softmax(dim=-1)
        out = torch.einsum("bhnm,bhcm->bhcn", attn, v).reshape(B, C, -1)
        return x + self.proj(out).view(B, C, H, W)


class Downsample(nn.Module):
    """Strided 3x3 conv; output spatial size is ``ceil(size / 2)``."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class Upsample(nn.Module):
    """Nearest-neighbour upsampling followed by a 3x3 conv.

    By default the spatial size is doubled. Passing ``size`` upsamples to that
    exact ``(H, W)`` instead, which lets a UNet undo a ``ceil(size / 2)``
    downsample on odd resolutions.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor, size: Optional[Sequence[int]] = None) -> torch.Tensor:
        if size is None:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        else:
            x = F.interpolate(x, size=tuple(size), mode="nearest")
        return self.conv(x)


class VelocityUNet(nn.Module):
    """UNet velocity field for image experiments (Dhariwal & Nichol 2021).

    MNIST: channels=32, depth=1, ch_mult=[1,2,2], heads=1
    CIFAR-10: channels=128, depth=2, ch_mult=[1,2,2,2], heads=4, head_ch=64

    ``attention_resolutions`` lists spatial sizes (in pixels) at which an
    AttentionBlock follows each ResBlock. The output has the same shape as
    the input image.
    """

    def __init__(self, image_size: int = 32, in_channels: int = 3,
                 model_channels: int = 128, num_res_blocks: int = 2,
                 channel_mult: Sequence[int] = (1, 2, 2, 2),
                 attention_resolutions: Sequence[int] = (16,),
                 num_heads: int = 4, num_head_channels: int = 64,
                 dropout: float = 0.0, use_scale_shift_norm: bool = True):
        super().__init__()
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_res_blocks = num_res_blocks
        channel_mult = tuple(channel_mult)
        attention_resolutions = tuple(attention_resolutions)
        time_dim = model_channels * 4

        def attention_or_identity(channels: int, resolution: int) -> nn.Module:
            # Attention only at the configured spatial resolutions.
            if resolution in attention_resolutions:
                return AttentionBlock(channels, num_heads, num_head_channels)
            return nn.Identity()

        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(model_channels),
            nn.Linear(model_channels, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )
        self.input_conv = nn.Conv2d(in_channels, model_channels, 3, padding=1)

        # ---- Encoder ----
        self.down_blocks = nn.ModuleList()
        self.down_attns = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        ch = model_channels
        resolution = image_size
        level_resolutions: List[int] = []  # spatial size at each level
        input_block_channels = [ch]  # channel count of every skip tensor
        for level, mult in enumerate(channel_mult):
            out_ch = model_channels * mult
            level_resolutions.append(resolution)
            for _ in range(num_res_blocks):
                self.down_blocks.append(
                    ResBlock(ch, time_dim, out_ch, dropout, use_scale_shift_norm))
                self.down_attns.append(attention_or_identity(out_ch, resolution))
                ch = out_ch
                input_block_channels.append(ch)
            if level < len(channel_mult) - 1:
                self.downsamplers.append(Downsample(ch))
                # Matches the stride-2 / pad-1 / k3 conv output: ceil(res / 2).
                resolution = (resolution + 1) // 2
                input_block_channels.append(ch)
            else:
                self.downsamplers.append(nn.Identity())

        # ---- Bottleneck ----
        self.mid_block1 = ResBlock(ch, time_dim, ch, dropout, use_scale_shift_norm)
        self.mid_attn = AttentionBlock(ch, num_heads, num_head_channels)
        self.mid_block2 = ResBlock(ch, time_dim, ch, dropout, use_scale_shift_norm)

        # ---- Decoder: num_res_blocks + 1 blocks per level, each eats a skip ----
        self.up_blocks = nn.ModuleList()
        self.up_attns = nn.ModuleList()
        self.upsamplers = nn.ModuleList()
        for level in reversed(range(len(channel_mult))):
            out_ch = model_channels * channel_mult[level]
            resolution = level_resolutions[level]
            for _ in range(num_res_blocks + 1):
                skip_ch = input_block_channels.pop()
                self.up_blocks.append(
                    ResBlock(ch + skip_ch, time_dim, out_ch, dropout, use_scale_shift_norm))
                self.up_attns.append(attention_or_identity(out_ch, resolution))
                ch = out_ch
            if level > 0:
                self.upsamplers.append(Upsample(ch))
            else:
                self.upsamplers.append(nn.Identity())

        self.out_norm = _group_norm(ch)
        self.out_conv = nn.Conv2d(ch, in_channels, 3, padding=1)
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict the velocity field for images ``x`` (B, C, H, W) at times ``t``."""
        # Scale t before embedding (assumes t roughly in [0, 1] — TODO confirm
        # with the training loop).
        emb = self.time_embed(t * 1000.0)
        n_down = self.num_res_blocks
        n_up = n_down + 1

        h = self.input_conv(x)
        skips = [h]
        for level, downsampler in enumerate(self.downsamplers):
            for i in range(level * n_down, (level + 1) * n_down):
                h = self.down_attns[i](self.down_blocks[i](h, emb))
                skips.append(h)
            if not isinstance(downsampler, nn.Identity):
                h = downsampler(h)
                skips.append(h)

        h = self.mid_block1(h, emb)
        h = self.mid_attn(h)
        h = self.mid_block2(h, emb)

        for level, upsampler in enumerate(self.upsamplers):
            for i in range(level * n_up, (level + 1) * n_up):
                h = torch.cat([h, skips.pop()], dim=1)
                h = self.up_attns[i](self.up_blocks[i](h, emb))
            if not isinstance(upsampler, nn.Identity):
                # Upsample to the next skip's exact size (handles odd sizes).
                h = upsampler(h, size=skips[-1].shape[-2:])

        h = F.silu(self.out_norm(h))
        return self.out_conv(h)

    def _get_num_res_blocks(self):
        """Number of ResBlocks per encoder level."""
        return self.num_res_blocks


class PhaseTransitionPredictor(nn.Module):
    """CNN predicting phase-transition time t_ϕ(x) ∈ [0, 1].

    4 conv layers (32→64→128→256), 3x3, AvgPool2d, FC + Sigmoid.

    Each conv stage is Conv3x3 -> ReLU -> AvgPool(2), so the spatial size is
    floor-halved once per entry of ``conv_channels``. ``forward`` returns a
    tensor of shape (B,).

    Raises:
        ValueError: if ``image_size`` is too small to survive the pooling.
    """

    def __init__(self, in_channels: int = 1, image_size: int = 28,
                 conv_channels: Sequence[int] = (32, 64, 128, 256)):
        super().__init__()
        layers: List[nn.Module] = []
        ch = in_channels
        final_size = image_size
        for out_ch in conv_channels:
            layers.extend([
                nn.Conv2d(ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.AvgPool2d(kernel_size=2, stride=2),
            ])
            ch = out_ch
            final_size //= 2  # AvgPool2d(2, 2) floors odd sizes
        if final_size < 1:
            raise ValueError(
                f"image_size={image_size} is too small for {len(conv_channels)} pooling stages"
            )
        self.conv = nn.Sequential(*layers)
        self.fc_input_dim = ch * final_size * final_size
        self.fc = nn.Linear(self.fc_input_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.conv(x).flatten(1)
        return self.sigmoid(self.fc(h)).squeeze(-1)


# Factory functions

def create_velocity_model_2d(config: dict) -> VelocityMLP:
    """Build a VelocityMLP from ``config["model"]`` (missing keys use defaults)."""
    model_cfg = config.get("model", {})
    return VelocityMLP(
        input_dim=model_cfg.get("input_dim", 2),
        hidden_dim=model_cfg.get("hidden_dim", 256),
        num_hidden_layers=model_cfg.get("num_hidden_layers", 3),
        time_emb_dim=model_cfg.get("time_emb_dim", 64),
    )


def create_velocity_unet(config: dict) -> VelocityUNet:
    """Build a VelocityUNet; image keys come from ``config``, the rest from ``config["unet"]``."""
    unet_cfg = config.get("unet", {})
    return VelocityUNet(
        image_size=config.get("image_size", 32),
        in_channels=config.get("in_channels", 3),
        model_channels=unet_cfg.get("model_channels", 128),
        num_res_blocks=unet_cfg.get("num_res_blocks", 2),
        channel_mult=unet_cfg.get("channel_mult", [1, 2, 2, 2]),
        attention_resolutions=unet_cfg.get("attention_resolutions", [16]),
        num_heads=unet_cfg.get("num_heads", 4),
        num_head_channels=unet_cfg.get("num_head_channels", 64),
        dropout=unet_cfg.get("dropout", 0.0),
        use_scale_shift_norm=unet_cfg.get("use_scale_shift_norm", True),
    )


def create_phase_predictor(config: dict) -> PhaseTransitionPredictor:
    """Build a PhaseTransitionPredictor; conv widths come from ``config["time_predictor"]``."""
    tp_cfg = config.get("time_predictor", {})
    return PhaseTransitionPredictor(
        in_channels=config.get("in_channels", 1),
        image_size=config.get("image_size", 28),
        conv_channels=tp_cfg.get("conv_channels", [32, 64, 128, 256]),
    )