""" LiquidDiffusion Model — A Novel Attention-Free Image Generation Architecture Core Innovation: Parallel Liquid Neural Network blocks for image generation. The CfC (Closed-form Continuous-depth) time-gating mechanism naturally bridges with diffusion timesteps — the diffusion noise level IS the liquid time constant. Mathematical Foundation: CfC Eq.10: x(t) = σ(-f·t) ⊙ g + (1 - σ(-f·t)) ⊙ h For image generation, we adapt this as: φ'(t) = σ(-f(φ)·t_diff) ⊙ g(φ) + (1 - σ(-f(φ)·t_diff)) ⊙ h(φ) Where t_diff is the diffusion timestep, f/g/h are spatial feature transforms. This is FULLY PARALLEL — no ODE solver, no sequential scanning. Additionally, we use learnable exponential relaxation (from LiquidTAD): α = exp(-λ·t_diff), out = α·φ + (1-α)·S(φ) This gives depth-dependent, time-aware residual connections. Architecture: Input (noisy image) → Conv stem → [Encoder: DownBlocks with LiquidCfC] → Bottleneck (LiquidCfC) → [Decoder: UpBlocks with LiquidCfC + skip] → Conv head → Velocity prediction (for rectified flow) No attention anywhere. All spatial mixing via depthwise convolutions + multi-scale parallel processing in liquid blocks. References: [1] Hasani et al., "Closed-form Continuous-time Neural Networks", Nature MI 2022 (CfC) [2] arxiv 2604.18274 — LiquidTAD (parallel liquid relaxation) [3] arxiv 2504.13499 — USM (U-Shape Mamba for diffusion) [4] Liu et al., "Flow Straight and Fast: Rectified Flow", ICLR 2023 """ import math import torch import torch.nn as nn import torch.nn.functional as F # ============================================================================= # 1. TIME EMBEDDING — Sinusoidal + MLP # ============================================================================= class SinusoidalTimeEmbedding(nn.Module): """Maps scalar timestep t to a high-dimensional embedding. Uses sinusoidal positional encoding followed by 2-layer MLP. """ def __init__(self, dim: int, max_period: int = 10000): super().__init__() self.dim = dim self.max_period = max_period self.mlp = nn.Sequential( nn.Linear(dim, dim * 4), nn.SiLU(), nn.Linear(dim * 4, dim), ) def forward(self, t: torch.Tensor) -> torch.Tensor: """t: [B] timestep values in [0, 1] → [B, dim] embeddings""" half = self.dim // 2 freqs = torch.exp( -math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / half ) args = t[:, None] * freqs[None, :] emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if self.dim % 2 == 1: emb = F.pad(emb, (0, 1)) return self.mlp(emb) # ============================================================================= # 2. ADAPTIVE LAYER NORM (AdaLN) — Timestep conditioning via scale/shift # ============================================================================= class AdaLN(nn.Module): """Adaptive Layer Normalization: out = norm(x) * (1 + scale(t)) + shift(t)""" def __init__(self, dim: int, cond_dim: int): super().__init__() # Find largest valid group count ≤ 32 num_groups = min(32, dim) while dim % num_groups != 0: num_groups -= 1 self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=dim, affine=False) self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2)) def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor: """x: [B,C,H,W], t_emb: [B, cond_dim] → [B,C,H,W]""" scale, shift = self.proj(t_emb).chunk(2, dim=1) return self.norm(x) * (1 + scale[:, :, None, None]) + shift[:, :, None, None] # ============================================================================= # 3. PARALLEL CfC BLOCK — Core liquid neural network layer # ============================================================================= class ParallelCfCBlock(nn.Module): """Parallel Closed-form Continuous-depth block for spatial features. CfC Eq.10: x(t) = σ(-f·t) ⊙ g + (1 - σ(-f·t)) ⊙ h Optimized design: - Single depthwise conv in backbone provides spatial context - f/g/h heads are cheap 1×1 projections from the shared backbone - No redundant large-kernel convolutions in the heads - Liquid relaxation residual: α·input + (1-α)·CfC_output """ def __init__(self, dim: int, t_dim: int, expand_ratio: float = 2.0, kernel_size: int = 5, dropout: float = 0.0): super().__init__() hidden = int(dim * expand_ratio) # Shared backbone: ONE depthwise conv provides all spatial context self.backbone = nn.Sequential( nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim), nn.Conv2d(dim, hidden, 1), nn.SiLU(), ) # Three CfC heads — all lightweight 1x1 projections (spatial info already in backbone) self.f_head = nn.Conv2d(hidden, dim, 1) # time-constant gate self.g_head = nn.Conv2d(hidden, dim, 1) # "from" state self.h_head = nn.Conv2d(hidden, dim, 1) # "to" state (attractor) # CfC time parameters self.time_a = nn.Linear(t_dim, dim) self.time_b = nn.Linear(t_dim, dim) # Liquid relaxation decay (LiquidTAD-inspired) self.rho = nn.Parameter(torch.zeros(1, dim, 1, 1)) # Output gate conditioned on timestep self.output_gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim)) self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor: """x: [B,C,H,W], t_emb: [B, t_dim] → [B,C,H,W]""" residual = x # Shared backbone — single spatial conv + expand bb = self.backbone(x) # Three CfC heads (all 1x1 — fast) f = self.f_head(bb) g = self.g_head(bb) h = self.h_head(bb) # CfC time-gating: σ(time_a(t) · f - time_b(t)) ta = self.time_a(t_emb)[:, :, None, None] tb = self.time_b(t_emb)[:, :, None, None] gate = torch.sigmoid(ta * f - tb) # CfC interpolation: gate*g + (1-gate)*h cfc_out = self.dropout(gate * g + (1.0 - gate) * h) # Liquid relaxation: α = exp(-λ · |t_mean|) t_scalar = t_emb.mean(dim=1, keepdim=True)[:, :, None, None] alpha = torch.exp(-(F.softplus(self.rho) + 1e-6) * t_scalar.abs().clamp(min=0.01)) out = alpha * residual + (1.0 - alpha) * cfc_out # Output gate return out * torch.sigmoid(self.output_gate(t_emb))[:, :, None, None] # ============================================================================= # 4. MULTI-SCALE SPATIAL MIXING — Global context without attention # ============================================================================= class MultiScaleSpatialMix(nn.Module): """Spatial mixing via single large-kernel depthwise conv + global pooling. Replaces the previous 3-conv (3x3+5x5+7x7) design with a single depthwise conv for local context + global average pooling for global context. 2 branches instead of 4 → ~3x faster. """ def __init__(self, dim: int, t_dim: int, kernel_size: int = 7): super().__init__() self.local_dw = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim) self.global_pool = nn.AdaptiveAvgPool2d(1) self.global_proj = nn.Conv2d(dim, dim, 1) self.merge = nn.Conv2d(dim * 2, dim, 1) self.act = nn.SiLU() self.adaln = AdaLN(dim, t_dim) def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor: x_norm = self.adaln(x, t_emb) local_feat = self.local_dw(x_norm) global_feat = self.global_proj(self.global_pool(x_norm)).expand_as(x_norm) return x + self.act(self.merge(torch.cat([local_feat, global_feat], dim=1))) # ============================================================================= # 5. LIQUID DIFFUSION BLOCK — Complete processing unit # ============================================================================= class LiquidDiffusionBlock(nn.Module): """One complete LiquidDiffusion block: AdaLN → ParallelCfC → SpatialMix → FeedForward """ def __init__(self, dim: int, t_dim: int, expand_ratio: float = 2.0, kernel_size: int = 5, dropout: float = 0.0): super().__init__() self.adaln1 = AdaLN(dim, t_dim) self.cfc = ParallelCfCBlock(dim, t_dim, expand_ratio, kernel_size, dropout) self.spatial_mix = MultiScaleSpatialMix(dim, t_dim, kernel_size) self.adaln2 = AdaLN(dim, t_dim) ff_dim = int(dim * expand_ratio) self.ff = nn.Sequential( nn.Conv2d(dim, ff_dim, 1), nn.SiLU(), nn.Conv2d(ff_dim, dim, 1), ) self.res_scale = nn.Parameter(torch.ones(1) * 0.1) def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor: x = x + self.res_scale * self.cfc(self.adaln1(x, t_emb), t_emb) x = self.spatial_mix(x, t_emb) x = x + self.res_scale * self.ff(self.adaln2(x, t_emb)) return x # ============================================================================= # 6. DOWN/UP SAMPLE + SKIP FUSION # ============================================================================= class DownSample(nn.Module): """Strided convolution downsampling (2x).""" def __init__(self, in_dim: int, out_dim: int): super().__init__() self.conv = nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1) def forward(self, x): return self.conv(x) class UpSample(nn.Module): """Nearest-neighbor interpolation + conv upsampling (2x).""" def __init__(self, in_dim: int, out_dim: int): super().__init__() self.conv = nn.Conv2d(in_dim, out_dim, 3, padding=1) def forward(self, x): return self.conv(F.interpolate(x, scale_factor=2, mode='nearest')) class SkipFusion(nn.Module): """Timestep-gated skip connection fusion.""" def __init__(self, dim: int, t_dim: int): super().__init__() self.proj = nn.Conv2d(dim * 2, dim, 1) self.gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim), nn.Sigmoid()) def forward(self, x, skip, t_emb): merged = self.proj(torch.cat([x, skip], dim=1)) g = self.gate(t_emb)[:, :, None, None] return merged * g + x * (1 - g) # ============================================================================= # 7. LIQUID DIFFUSION U-NET — The complete denoiser # ============================================================================= class LiquidDiffusionUNet(nn.Module): """LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks. U-Net where every processing block uses Parallel CfC layers instead of attention. The diffusion timestep serves dual purpose: 1. Conditions the denoiser via AdaLN scale/shift 2. Acts as CfC "time parameter" — controlling liquid neuron interpolation Scales: tiny: channels=[64,128,256], blocks=[2,2,4], ~8M (256px, fast) small: channels=[96,192,384], blocks=[2,3,6], ~25M (256px, quality) base: channels=[128,256,512], blocks=[2,4,8], ~65M (512px) large: channels=[128,256,512,768],blocks=[2,4,8,4], ~120M (512px HQ) """ def __init__(self, in_channels=3, channels=None, blocks_per_stage=None, t_dim=256, expand_ratio=2.0, kernel_size=5, dropout=0.0): super().__init__() if channels is None: channels = [64, 128, 256] if blocks_per_stage is None: blocks_per_stage = [2, 2, 4] assert len(channels) == len(blocks_per_stage) self.channels = channels self.num_stages = len(channels) # Time embedding self.time_embed = SinusoidalTimeEmbedding(t_dim) # Input stem self.stem = nn.Sequential( nn.Conv2d(in_channels, channels[0], 3, padding=1), nn.SiLU(), nn.Conv2d(channels[0], channels[0], 3, padding=1), ) # Encoder self.encoder_blocks = nn.ModuleList() self.downsamplers = nn.ModuleList() for i in range(self.num_stages): stage = nn.ModuleList() for _ in range(blocks_per_stage[i]): stage.append(LiquidDiffusionBlock( channels[i], t_dim, expand_ratio, kernel_size, dropout)) self.encoder_blocks.append(stage) if i < self.num_stages - 1: self.downsamplers.append(DownSample(channels[i], channels[i + 1])) # Bottleneck self.bottleneck = nn.ModuleList([ LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout), LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout), ]) # Decoder self.decoder_blocks = nn.ModuleList() self.upsamplers = nn.ModuleList() self.skip_fusions = nn.ModuleList() for i in range(self.num_stages - 1, -1, -1): if i < self.num_stages - 1: self.upsamplers.append(UpSample(channels[i + 1], channels[i])) self.skip_fusions.append(SkipFusion(channels[i], t_dim)) stage = nn.ModuleList() for _ in range(blocks_per_stage[i]): stage.append(LiquidDiffusionBlock( channels[i], t_dim, expand_ratio, kernel_size, dropout)) self.decoder_blocks.append(stage) # Output head (initialized to zero for stable start) head_groups = min(32, channels[0]) while channels[0] % head_groups != 0: head_groups -= 1 self.head = nn.Sequential( nn.GroupNorm(head_groups, channels[0]), nn.SiLU(), nn.Conv2d(channels[0], in_channels, 3, padding=1), ) nn.init.zeros_(self.head[-1].weight) nn.init.zeros_(self.head[-1].bias) def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: """ Args: x: [B, C, H, W] noisy image t: [B] timestep values in [0, 1] Returns: [B, C, H, W] predicted velocity """ t_emb = self.time_embed(t) h = self.stem(x) # Encoder skips = [] for i in range(self.num_stages): for block in self.encoder_blocks[i]: h = block(h, t_emb) skips.append(h) if i < self.num_stages - 1: h = self.downsamplers[i](h) # Bottleneck for block in self.bottleneck: h = block(h, t_emb) # Decoder up_idx = 0 for dec_i in range(self.num_stages): stage_idx = self.num_stages - 1 - dec_i if dec_i > 0: h = self.upsamplers[up_idx](h) h = self.skip_fusions[up_idx](h, skips[stage_idx], t_emb) up_idx += 1 for block in self.decoder_blocks[dec_i]: h = block(h, t_emb) return self.head(h) def count_params(self): total = sum(p.numel() for p in self.parameters()) trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) return total, trainable # ============================================================================= # 8. MODEL CONFIGS # ============================================================================= def liquid_diffusion_tiny(**kwargs): """~23M params, 256px, fits ~6GB VRAM.""" return LiquidDiffusionUNet( channels=[64, 128, 256], blocks_per_stage=[2, 2, 4], t_dim=256, expand_ratio=2.0, kernel_size=5, **kwargs) def liquid_diffusion_small(**kwargs): """~69M params, 256px, fits ~10GB VRAM.""" return LiquidDiffusionUNet( channels=[96, 192, 384], blocks_per_stage=[2, 3, 6], t_dim=384, expand_ratio=2.0, kernel_size=5, **kwargs) def liquid_diffusion_base(**kwargs): """~154M params, 512px, fits ~16GB VRAM.""" return LiquidDiffusionUNet( channels=[128, 256, 512], blocks_per_stage=[2, 4, 8], t_dim=512, expand_ratio=2.0, kernel_size=5, **kwargs) def liquid_diffusion_large(**kwargs): """~120M params, 512px, needs ~24GB VRAM.""" return LiquidDiffusionUNet( channels=[128, 256, 512, 768], blocks_per_stage=[2, 4, 8, 4], t_dim=512, expand_ratio=2.0, kernel_size=5, **kwargs)