| """ |
| LiquidDiffusion Model — A Novel Attention-Free Image Generation Architecture |
| |
| Core Innovation: Parallel Liquid Neural Network blocks for image generation. |
The CfC (Closed-form Continuous-time) time-gating mechanism maps naturally
onto diffusion timesteps: the diffusion noise level plays the role of the
liquid time constant.
| |
| Mathematical Foundation: |
| CfC Eq.10: x(t) = σ(-f·t) ⊙ g + (1 - σ(-f·t)) ⊙ h |
| |
| For image generation, we adapt this as: |
    φ_out = σ(-f(φ)·t_diff) ⊙ g(φ) + (1 - σ(-f(φ)·t_diff)) ⊙ h(φ)
| |
| Where t_diff is the diffusion timestep, f/g/h are spatial feature transforms. |
| This is FULLY PARALLEL — no ODE solver, no sequential scanning. |
| |
| Additionally, we use learnable exponential relaxation (from LiquidTAD): |
| α = exp(-λ·t_diff), out = α·φ + (1-α)·S(φ) |
| This gives depth-dependent, time-aware residual connections. |
| |
| Architecture: |
| Input (noisy image) → Conv stem → [Encoder: DownBlocks with LiquidCfC] |
| → Bottleneck (LiquidCfC) → [Decoder: UpBlocks with LiquidCfC + skip] |
| → Conv head → Velocity prediction (for rectified flow) |
| |
No attention anywhere. All spatial mixing happens through depthwise
convolutions plus a pooled global-context branch inside the liquid blocks.
| |
| References: |
    [1] Hasani et al., "Closed-form Continuous-time Neural Networks", Nature Machine Intelligence, 2022 (CfC)
    [2] LiquidTAD (parallel liquid relaxation), arXiv:2604.18274
    [3] USM (U-Shape Mamba for diffusion), arXiv:2504.13499
    [4] Liu et al., "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow", ICLR 2023
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
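# Worked example (ours, illustrative only): how the CfC gate sigma(-f*t) from
# the docstring behaves. At t = 0 the gate is 0.5, so g and h contribute
# equally; as t grows with f > 0 the gate decays and the output slides toward
# h. In this model the diffusion timestep plays the role of t.
def _demo_cfc_gate() -> None:
    f, g, h = torch.tensor(2.0), torch.tensor(1.0), torch.tensor(-1.0)
    for t in (0.0, 0.5, 1.0):
        gate = torch.sigmoid(-f * t)
        x = gate * g + (1.0 - gate) * h
        print(f"t={t:.1f}  gate={gate.item():.3f}  x={x.item():+.3f}")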
|
|
|
| |
| |
| |
|
|
| class SinusoidalTimeEmbedding(nn.Module): |
| """Maps scalar timestep t to a high-dimensional embedding. |
    Uses sinusoidal positional encoding followed by a 2-layer MLP.
| """ |
| def __init__(self, dim: int, max_period: int = 10000): |
| super().__init__() |
| self.dim = dim |
| self.max_period = max_period |
| self.mlp = nn.Sequential( |
| nn.Linear(dim, dim * 4), |
| nn.SiLU(), |
| nn.Linear(dim * 4, dim), |
| ) |
|
|
| def forward(self, t: torch.Tensor) -> torch.Tensor: |
| """t: [B] timestep values in [0, 1] → [B, dim] embeddings""" |
| half = self.dim // 2 |
| freqs = torch.exp( |
| -math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / half |
| ) |
| args = t[:, None] * freqs[None, :] |
| emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) |
| if self.dim % 2 == 1: |
| emb = F.pad(emb, (0, 1)) |
| return self.mlp(emb) |
|
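# Illustrative usage (ours, not part of the original API): embed a batch of
# timesteps and verify the output shape.
def _demo_time_embedding() -> None:
    emb = SinusoidalTimeEmbedding(dim=256)
    t = torch.rand(8)            # 8 timesteps in [0, 1]
    out = emb(t)                 # -> [8, 256]
    assert out.shape == (8, 256)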
|
|
|
| |
| |
| |
|
|
class AdaLN(nn.Module):
    """Adaptive normalization (GroupNorm-based AdaLN):
    out = norm(x) * (1 + scale(t)) + shift(t)
    """
| def __init__(self, dim: int, cond_dim: int): |
| super().__init__() |
| |
        # Largest group count <= 32 that divides the channel dim evenly.
        num_groups = min(32, dim)
        while dim % num_groups != 0:
            num_groups -= 1
| self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=dim, affine=False) |
| self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2)) |
|
|
| def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor: |
| """x: [B,C,H,W], t_emb: [B, cond_dim] → [B,C,H,W]""" |
| scale, shift = self.proj(t_emb).chunk(2, dim=1) |
| return self.norm(x) * (1 + scale[:, :, None, None]) + shift[:, :, None, None] |
|
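# Illustrative sketch (ours): AdaLN normalizes the feature map, then applies a
# per-channel scale and shift computed from the timestep embedding, so the
# same weights denoise differently at different noise levels.
def _demo_adaln() -> None:
    norm = AdaLN(dim=64, cond_dim=256)
    x = torch.randn(2, 64, 32, 32)
    t_emb = torch.randn(2, 256)
    assert norm(x, t_emb).shape == x.shape   # modulation is shape-preserving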
|
|
|
| |
| |
| |
|
|
class ParallelCfCBlock(nn.Module):
    """Parallel Closed-form Continuous-time block for spatial features.
| |
| CfC Eq.10: x(t) = σ(-f·t) ⊙ g + (1 - σ(-f·t)) ⊙ h |
| |
| Optimized design: |
| - Single depthwise conv in backbone provides spatial context |
| - f/g/h heads are cheap 1×1 projections from the shared backbone |
| - No redundant large-kernel convolutions in the heads |
| - Liquid relaxation residual: α·input + (1-α)·CfC_output |
| """ |
| def __init__(self, dim: int, t_dim: int, expand_ratio: float = 2.0, |
| kernel_size: int = 5, dropout: float = 0.0): |
| super().__init__() |
| hidden = int(dim * expand_ratio) |
| |
| |
| self.backbone = nn.Sequential( |
| nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim), |
| nn.Conv2d(dim, hidden, 1), |
| nn.SiLU(), |
| ) |
| |
| |
        # CfC heads: cheap 1x1 projections from the shared backbone features.
        self.f_head = nn.Conv2d(hidden, dim, 1)
        self.g_head = nn.Conv2d(hidden, dim, 1)
        self.h_head = nn.Conv2d(hidden, dim, 1)

        # Timestep-conditioned affine parameters for the CfC gate.
        self.time_a = nn.Linear(t_dim, dim)
        self.time_b = nn.Linear(t_dim, dim)

        # Learnable per-channel relaxation rate for α = exp(-λ·t).
        self.rho = nn.Parameter(torch.zeros(1, dim, 1, 1))

        # Timestep-dependent output gate.
        self.output_gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim))
| |
| self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() |
|
|
    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """x: [B,C,H,W], t_emb: [B, t_dim] → [B,C,H,W]"""
        residual = x

        # Shared backbone: depthwise conv for spatial context + 1x1 expansion.
        bb = self.backbone(x)

        # CfC heads from the shared features.
        f = self.f_head(bb)
        g = self.g_head(bb)
        h = self.h_head(bb)

        # Time-conditioned gate: a learned affine generalization of the CfC
        # gate σ(-f·t); the sign of -t is absorbed into time_a/time_b.
        ta = self.time_a(t_emb)[:, :, None, None]
        tb = self.time_b(t_emb)[:, :, None, None]
        gate = torch.sigmoid(ta * f - tb)

        # CfC Eq. 10: interpolate between the g and h branches.
        cfc_out = self.dropout(gate * g + (1.0 - gate) * h)

        # Liquid relaxation residual (LiquidTAD): α = exp(-λ·t), using the
        # mean of t_emb as a scalar proxy for the diffusion time.
        t_scalar = t_emb.mean(dim=1, keepdim=True)[:, :, None, None]
        alpha = torch.exp(-(F.softplus(self.rho) + 1e-6) * t_scalar.abs().clamp(min=0.01))
        out = alpha * residual + (1.0 - alpha) * cfc_out

        # Final timestep-dependent output gate.
        return out * torch.sigmoid(self.output_gate(t_emb))[:, :, None, None]
|
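# Illustrative sketch (ours): the CfC block is shape-preserving and fully
# parallel over pixels; nothing in forward() is recurrent or solves an ODE.
def _demo_cfc_block() -> None:
    block = ParallelCfCBlock(dim=64, t_dim=256)
    x = torch.randn(2, 64, 32, 32)
    t_emb = torch.randn(2, 256)
    assert block(x, t_emb).shape == x.shape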
|
|
|
| |
| |
| |
|
|
class MultiScaleSpatialMix(nn.Module):
    """Spatial mixing via a single large-kernel depthwise conv + global pooling.

    Two branches (a depthwise conv for local context, global average pooling
    for image-level context) replace the earlier four-branch design built on
    3x3 + 5x5 + 7x7 depthwise convs, making spatial mixing roughly 3x faster.
    """
| def __init__(self, dim: int, t_dim: int, kernel_size: int = 7): |
| super().__init__() |
| self.local_dw = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim) |
| self.global_pool = nn.AdaptiveAvgPool2d(1) |
| self.global_proj = nn.Conv2d(dim, dim, 1) |
| self.merge = nn.Conv2d(dim * 2, dim, 1) |
| self.act = nn.SiLU() |
| self.adaln = AdaLN(dim, t_dim) |
|
|
| def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor: |
| x_norm = self.adaln(x, t_emb) |
| local_feat = self.local_dw(x_norm) |
| global_feat = self.global_proj(self.global_pool(x_norm)).expand_as(x_norm) |
| return x + self.act(self.merge(torch.cat([local_feat, global_feat], dim=1))) |
|
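# Illustrative sketch (ours): thanks to the pooled branch (and the shared
# normalization statistics), perturbing one pixel changes the output even at
# the far corner of the map, i.e. the mixer has global reach in a single pass.
def _demo_spatial_mix_global_reach() -> None:
    torch.manual_seed(0)
    mix = MultiScaleSpatialMix(dim=8, t_dim=32)
    t_emb = torch.randn(1, 32)
    x = torch.randn(1, 8, 16, 16)
    x2 = x.clone()
    x2[0, 0, 0, 0] += 10.0       # perturb a single pixel in one corner
    with torch.no_grad():
        delta = (mix(x2, t_emb) - mix(x, t_emb)).abs()
    assert delta[0, :, 15, 15].max() > 0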
|
|
|
| |
| |
| |
|
|
| class LiquidDiffusionBlock(nn.Module): |
| """One complete LiquidDiffusion block: |
| AdaLN → ParallelCfC → SpatialMix → FeedForward |
| """ |
| def __init__(self, dim: int, t_dim: int, expand_ratio: float = 2.0, |
| kernel_size: int = 5, dropout: float = 0.0): |
| super().__init__() |
| self.adaln1 = AdaLN(dim, t_dim) |
| self.cfc = ParallelCfCBlock(dim, t_dim, expand_ratio, kernel_size, dropout) |
| self.spatial_mix = MultiScaleSpatialMix(dim, t_dim, kernel_size) |
| self.adaln2 = AdaLN(dim, t_dim) |
| ff_dim = int(dim * expand_ratio) |
| self.ff = nn.Sequential( |
| nn.Conv2d(dim, ff_dim, 1), nn.SiLU(), nn.Conv2d(ff_dim, dim, 1), |
| ) |
| self.res_scale = nn.Parameter(torch.ones(1) * 0.1) |
|
|
| def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor: |
| x = x + self.res_scale * self.cfc(self.adaln1(x, t_emb), t_emb) |
| x = self.spatial_mix(x, t_emb) |
| x = x + self.res_scale * self.ff(self.adaln2(x, t_emb)) |
| return x |
|
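# Illustrative sketch (ours): one full block chains CfC mixing, spatial mixing,
# and the 1x1 feed-forward, all conditioned on the same t_emb. res_scale
# starts at 0.1, so the residual branches are damped early in training.
def _demo_liquid_block() -> None:
    block = LiquidDiffusionBlock(dim=64, t_dim=256)
    x = torch.randn(2, 64, 32, 32)
    t_emb = torch.randn(2, 256)
    assert block(x, t_emb).shape == x.shape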
|
|
|
| |
| |
| |
|
|
| class DownSample(nn.Module): |
| """Strided convolution downsampling (2x).""" |
| def __init__(self, in_dim: int, out_dim: int): |
| super().__init__() |
| self.conv = nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1) |
| def forward(self, x): |
| return self.conv(x) |
|
|
|
|
| class UpSample(nn.Module): |
| """Nearest-neighbor interpolation + conv upsampling (2x).""" |
| def __init__(self, in_dim: int, out_dim: int): |
| super().__init__() |
| self.conv = nn.Conv2d(in_dim, out_dim, 3, padding=1) |
| def forward(self, x): |
| return self.conv(F.interpolate(x, scale_factor=2, mode='nearest')) |
|
|
|
|
| class SkipFusion(nn.Module): |
| """Timestep-gated skip connection fusion.""" |
| def __init__(self, dim: int, t_dim: int): |
| super().__init__() |
| self.proj = nn.Conv2d(dim * 2, dim, 1) |
| self.gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim), nn.Sigmoid()) |
| |
| def forward(self, x, skip, t_emb): |
| merged = self.proj(torch.cat([x, skip], dim=1)) |
| g = self.gate(t_emb)[:, :, None, None] |
| return merged * g + x * (1 - g) |
|
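# Illustrative sketch (ours): the timestep-derived gate g in [0, 1] decides,
# per channel, how much the merged (decoder + skip) projection replaces the
# plain decoder stream: g -> 0 keeps x, g -> 1 uses the merged features.
def _demo_skip_fusion() -> None:
    fuse = SkipFusion(dim=64, t_dim=256)
    x = torch.randn(2, 64, 32, 32)       # upsampled decoder features
    skip = torch.randn(2, 64, 32, 32)    # matching encoder features
    t_emb = torch.randn(2, 256)
    assert fuse(x, skip, t_emb).shape == x.shape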
|
|
|
| |
| |
| |
|
|
| class LiquidDiffusionUNet(nn.Module): |
| """LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks. |
| |
| U-Net where every processing block uses Parallel CfC layers instead of attention. |
    The diffusion timestep serves a dual purpose:
| 1. Conditions the denoiser via AdaLN scale/shift |
| 2. Acts as CfC "time parameter" — controlling liquid neuron interpolation |
| |
| Scales: |
        tiny:  channels=[64,128,256],      blocks=[2,2,4],   ~23M  (256px, fast)
        small: channels=[96,192,384],      blocks=[2,3,6],   ~69M  (256px, quality)
        base:  channels=[128,256,512],     blocks=[2,4,8],   ~154M (512px)
        large: channels=[128,256,512,768], blocks=[2,4,8,4], ~280M (512px HQ)
| """ |
| def __init__(self, in_channels=3, channels=None, blocks_per_stage=None, |
| t_dim=256, expand_ratio=2.0, kernel_size=5, dropout=0.0): |
| super().__init__() |
| if channels is None: |
| channels = [64, 128, 256] |
| if blocks_per_stage is None: |
| blocks_per_stage = [2, 2, 4] |
|
|
| assert len(channels) == len(blocks_per_stage) |
| self.channels = channels |
| self.num_stages = len(channels) |
| |
| |
| self.time_embed = SinusoidalTimeEmbedding(t_dim) |
| |
| |
| self.stem = nn.Sequential( |
| nn.Conv2d(in_channels, channels[0], 3, padding=1), |
| nn.SiLU(), |
| nn.Conv2d(channels[0], channels[0], 3, padding=1), |
| ) |
| |
| |
| self.encoder_blocks = nn.ModuleList() |
| self.downsamplers = nn.ModuleList() |
| for i in range(self.num_stages): |
| stage = nn.ModuleList() |
| for _ in range(blocks_per_stage[i]): |
| stage.append(LiquidDiffusionBlock( |
| channels[i], t_dim, expand_ratio, kernel_size, dropout)) |
| self.encoder_blocks.append(stage) |
| if i < self.num_stages - 1: |
| self.downsamplers.append(DownSample(channels[i], channels[i + 1])) |
| |
| |
| self.bottleneck = nn.ModuleList([ |
| LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout), |
| LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout), |
| ]) |
| |
| |
| self.decoder_blocks = nn.ModuleList() |
| self.upsamplers = nn.ModuleList() |
| self.skip_fusions = nn.ModuleList() |
| for i in range(self.num_stages - 1, -1, -1): |
| if i < self.num_stages - 1: |
| self.upsamplers.append(UpSample(channels[i + 1], channels[i])) |
| self.skip_fusions.append(SkipFusion(channels[i], t_dim)) |
| stage = nn.ModuleList() |
| for _ in range(blocks_per_stage[i]): |
| stage.append(LiquidDiffusionBlock( |
| channels[i], t_dim, expand_ratio, kernel_size, dropout)) |
| self.decoder_blocks.append(stage) |
| |
| |
        # Largest group count <= 32 that divides channels[0] evenly.
        head_groups = min(32, channels[0])
        while channels[0] % head_groups != 0:
            head_groups -= 1
| self.head = nn.Sequential( |
| nn.GroupNorm(head_groups, channels[0]), |
| nn.SiLU(), |
| nn.Conv2d(channels[0], in_channels, 3, padding=1), |
| ) |
        # Zero-init the final conv so the model starts by predicting zero velocity.
        nn.init.zeros_(self.head[-1].weight)
        nn.init.zeros_(self.head[-1].bias)
|
|
| def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: |
| """ |
| Args: |
| x: [B, C, H, W] noisy image |
| t: [B] timestep values in [0, 1] |
| Returns: |
| [B, C, H, W] predicted velocity |
| """ |
| t_emb = self.time_embed(t) |
| h = self.stem(x) |
| |
| |
        # Encoder: keep a skip per stage that gets downsampled; the deepest
        # stage flows straight into the bottleneck, so its skip is never used.
        skips = []
        for i in range(self.num_stages):
            for block in self.encoder_blocks[i]:
                h = block(h, t_emb)
            if i < self.num_stages - 1:
                skips.append(h)
                h = self.downsamplers[i](h)
| |
| |
| for block in self.bottleneck: |
| h = block(h, t_emb) |
| |
| |
        # Decoder: walk stages deepest-to-shallowest; after each upsample,
        # fuse the matching encoder skip (timestep-gated), then run blocks.
        up_idx = 0
        for dec_i in range(self.num_stages):
            stage_idx = self.num_stages - 1 - dec_i
            if dec_i > 0:
                h = self.upsamplers[up_idx](h)
                h = self.skip_fusions[up_idx](h, skips[stage_idx], t_emb)
                up_idx += 1
            for block in self.decoder_blocks[dec_i]:
                h = block(h, t_emb)
| |
| return self.head(h) |
|
|
| def count_params(self): |
| total = sum(p.numel() for p in self.parameters()) |
| trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) |
| return total, trainable |
|
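# Illustrative end-to-end check (ours): the default configuration is the tiny
# scale. Spatial dims must be divisible by 2**(num_stages - 1) (= 4 here) so
# encoder and decoder resolutions line up.
def _demo_unet() -> None:
    model = LiquidDiffusionUNet()
    x = torch.randn(2, 3, 64, 64)    # noisy images
    t = torch.rand(2)                # timesteps in [0, 1]
    v = model(x, t)                  # predicted velocity field
    assert v.shape == x.shape
    total, trainable = model.count_params()
    print(f"params: {total / 1e6:.1f}M total, {trainable / 1e6:.1f}M trainable")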
|
|
|
| |
| |
| |
|
|
| def liquid_diffusion_tiny(**kwargs): |
| """~23M params, 256px, fits ~6GB VRAM.""" |
| return LiquidDiffusionUNet( |
| channels=[64, 128, 256], blocks_per_stage=[2, 2, 4], |
| t_dim=256, expand_ratio=2.0, kernel_size=5, **kwargs) |
|
|
| def liquid_diffusion_small(**kwargs): |
| """~69M params, 256px, fits ~10GB VRAM.""" |
| return LiquidDiffusionUNet( |
| channels=[96, 192, 384], blocks_per_stage=[2, 3, 6], |
| t_dim=384, expand_ratio=2.0, kernel_size=5, **kwargs) |
|
|
| def liquid_diffusion_base(**kwargs): |
| """~154M params, 512px, fits ~16GB VRAM.""" |
| return LiquidDiffusionUNet( |
| channels=[128, 256, 512], blocks_per_stage=[2, 4, 8], |
| t_dim=512, expand_ratio=2.0, kernel_size=5, **kwargs) |
|
|
def liquid_diffusion_large(**kwargs):
    """~280M params, 512px, needs ~24GB VRAM."""
| return LiquidDiffusionUNet( |
| channels=[128, 256, 512, 768], blocks_per_stage=[2, 4, 8, 4], |
| t_dim=512, expand_ratio=2.0, kernel_size=5, **kwargs) |
|
|
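# Minimal rectified-flow training-step sketch (ours; see reference [4] in the
# module docstring). Under the straight-line interpolation
#     x_t = (1 - t) * x0 + t * x1,   x1 ~ N(0, I),
# the network's regression target is the constant velocity x1 - x0. Data
# loading, optimizers, EMA, and schedules are omitted.
def _rectified_flow_step(model: nn.Module, x0: torch.Tensor) -> torch.Tensor:
    x1 = torch.randn_like(x0)                       # noise endpoint
    t = torch.rand(x0.shape[0], device=x0.device)   # uniform timesteps
    t_b = t[:, None, None, None]
    xt = (1.0 - t_b) * x0 + t_b * x1                # point on the straight path
    v_pred = model(xt, t)
    return F.mse_loss(v_pred, x1 - x0)              # match the true velocity


if __name__ == "__main__":
    model = liquid_diffusion_tiny()
    total, _ = model.count_params()
    print(f"tiny: {total / 1e6:.1f}M params")
    loss = _rectified_flow_step(model, torch.randn(2, 3, 64, 64))
    print(f"rectified-flow loss: {loss.item():.4f}")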