File size: 16,886 Bytes
8a00562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19b13e1
 
 
 
 
8a00562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
761b206
 
 
 
 
8a00562
 
761b206
8a00562
 
 
761b206
 
 
 
8a00562
 
 
761b206
 
 
 
 
 
8a00562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
761b206
 
8a00562
761b206
 
 
 
8a00562
 
 
 
 
 
 
761b206
8a00562
 
 
761b206
8a00562
 
 
 
761b206
8a00562
 
 
 
 
 
 
761b206
8a00562
761b206
 
 
8a00562
761b206
8a00562
761b206
8a00562
 
761b206
8a00562
 
 
 
 
761b206
 
 
8a00562
 
 
 
 
 
 
 
761b206
8a00562
 
761b206
8a00562
 
 
761b206
8a00562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
761b206
8a00562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19b13e1
 
 
8a00562
19b13e1
8a00562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19b13e1
8a00562
 
761b206
8a00562
 
19b13e1
8a00562
 
761b206
8a00562
 
19b13e1
8a00562
 
761b206
8a00562
 
 
 
 
761b206
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
"""
LiquidDiffusion Model — A Novel Attention-Free Image Generation Architecture

Core Innovation: Parallel Liquid Neural Network blocks for image generation.
The CfC (Closed-form Continuous-depth) time-gating mechanism naturally bridges
with diffusion timesteps — the diffusion noise level IS the liquid time constant.

Mathematical Foundation:
    CfC Eq.10: x(t) = σ(-f·t) ⊙ g + (1 - σ(-f·t)) ⊙ h
    
    For image generation, we adapt this as:
    φ'(t) = σ(-f(φ)·t_diff) ⊙ g(φ) + (1 - σ(-f(φ)·t_diff)) ⊙ h(φ)
    
    Where t_diff is the diffusion timestep, f/g/h are spatial feature transforms.
    This is FULLY PARALLEL — no ODE solver, no sequential scanning.

    Additionally, we use learnable exponential relaxation (from LiquidTAD):
    α = exp(-λ·t_diff), out = α·φ + (1-α)·S(φ)
    This gives depth-dependent, time-aware residual connections.

Architecture:
    Input (noisy image) → Conv stem → [Encoder: DownBlocks with LiquidCfC]
    → Bottleneck (LiquidCfC) → [Decoder: UpBlocks with LiquidCfC + skip]
    → Conv head → Velocity prediction (for rectified flow)

No attention anywhere. All spatial mixing via depthwise convolutions +
multi-scale parallel processing in liquid blocks.

References:
    [1] Hasani et al., "Closed-form Continuous-time Neural Networks", Nature MI 2022 (CfC)
    [2] arxiv 2604.18274 — LiquidTAD (parallel liquid relaxation)
    [3] arxiv 2504.13499 — USM (U-Shape Mamba for diffusion)
    [4] Liu et al., "Flow Straight and Fast: Rectified Flow", ICLR 2023
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# =============================================================================
# 1. TIME EMBEDDING — Sinusoidal + MLP
# =============================================================================

class SinusoidalTimeEmbedding(nn.Module):
    """Maps scalar timestep t to a high-dimensional embedding.

    Uses sinusoidal positional encoding followed by a 2-layer MLP.

    Args:
        dim: output embedding dimension.
        max_period: longest sinusoid period (controls the frequency range).
    """
    def __init__(self, dim: int, max_period: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """t: [B] timestep values in [0, 1] → [B, dim] embeddings"""
        # Cast to fp32 so integer/double timesteps work: torch.exp is not
        # defined for integer dtypes, and the MLP weights are fp32.
        t = t.to(torch.float32)
        half = self.dim // 2
        # Geometrically spaced frequencies from 1 down to 1/max_period.
        freqs = torch.exp(
            -math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / half
        )
        args = t[:, None] * freqs[None, :]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2 == 1:
            # cos/sin pairs give 2*half features; pad one zero for odd dim.
            emb = F.pad(emb, (0, 1))
        return self.mlp(emb)


# =============================================================================
# 2. ADAPTIVE LAYER NORM (AdaLN) — Timestep conditioning via scale/shift
# =============================================================================

class AdaLN(nn.Module):
    """Adaptive Layer Normalization.

    Normalizes x with a parameter-free GroupNorm, then modulates it with a
    timestep-derived scale and shift: out = norm(x) * (1 + scale) + shift.
    """
    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        # Pick the largest group count <= 32 that divides dim evenly.
        groups = min(32, dim)
        while dim % groups:
            groups -= 1
        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, affine=False)
        self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2))

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """x: [B,C,H,W], t_emb: [B, cond_dim] → [B,C,H,W]"""
        modulation = self.proj(t_emb)
        scale, shift = modulation.chunk(2, dim=1)
        scale = scale[:, :, None, None]
        shift = shift[:, :, None, None]
        return self.norm(x) * (scale + 1) + shift


# =============================================================================
# 3. PARALLEL CfC BLOCK — Core liquid neural network layer
# =============================================================================

class ParallelCfCBlock(nn.Module):
    """Parallel Closed-form Continuous-depth (CfC) block for spatial features.

    CfC Eq.10: x(t) = σ(-f·t) ⊙ g + (1 - σ(-f·t)) ⊙ h

    A single depthwise conv in the shared backbone supplies all spatial
    context; the f/g/h heads are cheap 1×1 projections off it, so the block
    is fully parallel over the spatial grid — no ODE solver, no scanning.
    A liquid-relaxation residual α·input + (1-α)·CfC_output (α derived from
    the timestep embedding) blends the block output with its input.
    """
    def __init__(self, dim: int, t_dim: int, expand_ratio: float = 2.0,
                 kernel_size: int = 5, dropout: float = 0.0):
        super().__init__()
        hidden = int(dim * expand_ratio)
        
        # One depthwise conv for spatial context, then a 1x1 channel expand.
        self.backbone = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim),
            nn.Conv2d(dim, hidden, 1),
            nn.SiLU(),
        )
        
        # CfC heads: pointwise projections only (spatial info is in backbone).
        self.f_head = nn.Conv2d(hidden, dim, 1)   # drives the time gate
        self.g_head = nn.Conv2d(hidden, dim, 1)   # source ("from") state
        self.h_head = nn.Conv2d(hidden, dim, 1)   # target ("to") state / attractor
        
        # Per-channel affine maps from the timestep embedding into the gate.
        self.time_a = nn.Linear(t_dim, dim)
        self.time_b = nn.Linear(t_dim, dim)
        
        # Learnable decay rate for the liquid relaxation residual.
        self.rho = nn.Parameter(torch.zeros(1, dim, 1, 1))
        
        # Timestep-conditioned multiplicative output gate.
        self.output_gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim))
        
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """x: [B,C,H,W], t_emb: [B, t_dim] → [B,C,H,W]"""
        skip = x
        
        # Shared spatial backbone.
        feats = self.backbone(x)
        
        # The three CfC projections (all 1x1, cheap).
        f = self.f_head(feats)
        g = self.g_head(feats)
        h = self.h_head(feats)
        
        # Time gate: σ(time_a(t) · f - time_b(t)).
        a = self.time_a(t_emb).unsqueeze(-1).unsqueeze(-1)
        b = self.time_b(t_emb).unsqueeze(-1).unsqueeze(-1)
        sigma = torch.sigmoid(a * f - b)
        
        # CfC interpolation between source and attractor states.
        blended = self.dropout(sigma * g + (1.0 - sigma) * h)
        
        # Liquid relaxation: α = exp(-softplus(ρ)·|mean(t_emb)|), clamped away
        # from zero so the residual never fully collapses.
        t_bar = t_emb.mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1)
        decay = torch.exp(-(F.softplus(self.rho) + 1e-6) * t_bar.abs().clamp(min=0.01))
        
        mixed = decay * skip + (1.0 - decay) * blended
        
        # Timestep-conditioned output gate.
        return mixed * torch.sigmoid(self.output_gate(t_emb)).unsqueeze(-1).unsqueeze(-1)


# =============================================================================
# 4. MULTI-SCALE SPATIAL MIXING — Global context without attention
# =============================================================================

class MultiScaleSpatialMix(nn.Module):
    """Attention-free spatial mixing with a residual connection.

    One large-kernel depthwise conv captures local context and a global
    average pool (projected, then broadcast) captures image-wide context;
    the two are concatenated and merged by a pointwise conv. Two branches
    instead of the previous four-conv design — roughly 3x faster.
    """
    def __init__(self, dim: int, t_dim: int, kernel_size: int = 7):
        super().__init__()
        self.local_dw = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.global_proj = nn.Conv2d(dim, dim, 1)
        self.merge = nn.Conv2d(dim * 2, dim, 1)
        self.act = nn.SiLU()
        self.adaln = AdaLN(dim, t_dim)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """x: [B,C,H,W], t_emb: [B, t_dim] → [B,C,H,W]"""
        y = self.adaln(x, t_emb)
        local_ctx = self.local_dw(y)
        # Broadcast the pooled global descriptor back to the spatial grid.
        global_ctx = self.global_proj(self.global_pool(y)).expand_as(y)
        mixed = self.merge(torch.cat([local_ctx, global_ctx], dim=1))
        return x + self.act(mixed)


# =============================================================================
# 5. LIQUID DIFFUSION BLOCK — Complete processing unit
# =============================================================================

class LiquidDiffusionBlock(nn.Module):
    """One complete LiquidDiffusion processing unit.

    Pipeline: AdaLN → ParallelCfC → MultiScaleSpatialMix → AdaLN → FF,
    with residual connections around the CfC and feed-forward stages.
    Both residual branches are scaled by a shared learnable scalar
    (initialized to 0.1) to keep early training stable.
    """
    def __init__(self, dim: int, t_dim: int, expand_ratio: float = 2.0,
                 kernel_size: int = 5, dropout: float = 0.0):
        super().__init__()
        self.adaln1 = AdaLN(dim, t_dim)
        self.cfc = ParallelCfCBlock(dim, t_dim, expand_ratio, kernel_size, dropout)
        self.spatial_mix = MultiScaleSpatialMix(dim, t_dim, kernel_size)
        self.adaln2 = AdaLN(dim, t_dim)
        ff_dim = int(dim * expand_ratio)
        self.ff = nn.Sequential(
            nn.Conv2d(dim, ff_dim, 1), nn.SiLU(), nn.Conv2d(ff_dim, dim, 1),
        )
        self.res_scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """x: [B,C,H,W], t_emb: [B, t_dim] → [B,C,H,W]"""
        # Pre-norm liquid CfC stage with scaled residual.
        normed = self.adaln1(x, t_emb)
        x = x + self.res_scale * self.cfc(normed, t_emb)
        # Spatial mixing carries its own internal residual.
        x = self.spatial_mix(x, t_emb)
        # Pre-norm pointwise feed-forward, again with the scaled residual.
        x = x + self.res_scale * self.ff(self.adaln2(x, t_emb))
        return x


# =============================================================================
# 6. DOWN/UP SAMPLE + SKIP FUSION
# =============================================================================

class DownSample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class UpSample(nn.Module):
    """Double spatial resolution: nearest-neighbor upsample, then 3x3 conv."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(upsampled)


class SkipFusion(nn.Module):
    """Fuse a decoder feature map with its encoder skip, gated by timestep.

    A sigmoid gate w(t) blends the projected concatenation against the
    ungated decoder path: out = w·proj([x, skip]) + (1-w)·x.
    """
    def __init__(self, dim: int, t_dim: int):
        super().__init__()
        self.proj = nn.Conv2d(dim * 2, dim, 1)
        self.gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim), nn.Sigmoid())

    def forward(self, x, skip, t_emb):
        fused = self.proj(torch.cat([x, skip], dim=1))
        w = self.gate(t_emb)[:, :, None, None]
        return fused * w + x * (1 - w)


# =============================================================================
# 7. LIQUID DIFFUSION U-NET — The complete denoiser
# =============================================================================

class LiquidDiffusionUNet(nn.Module):
    """LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks.
    
    U-Net where every processing block uses Parallel CfC layers instead of attention.
    The diffusion timestep serves dual purpose:
        1. Conditions the denoiser via AdaLN scale/shift
        2. Acts as CfC "time parameter" — controlling liquid neuron interpolation
    
    Scales (parameter counts are rough estimates and disagree with the
    factory-function docstrings below — verify with count_params()):
        tiny:  channels=[64,128,256],     blocks=[2,2,4]     (256px, fast)
        small: channels=[96,192,384],     blocks=[2,3,6]     (256px, quality)
        base:  channels=[128,256,512],    blocks=[2,4,8]     (512px)
        large: channels=[128,256,512,768],blocks=[2,4,8,4]   (512px HQ)
    """
    def __init__(self, in_channels=3, channels=None, blocks_per_stage=None,
                 t_dim=256, expand_ratio=2.0, kernel_size=5, dropout=0.0):
        super().__init__()
        # Defaults reproduce the "tiny" configuration.
        if channels is None:
            channels = [64, 128, 256]
        if blocks_per_stage is None:
            blocks_per_stage = [2, 2, 4]

        assert len(channels) == len(blocks_per_stage)
        self.channels = channels
        self.num_stages = len(channels)
        
        # Time embedding shared by every block via t_emb.
        self.time_embed = SinusoidalTimeEmbedding(t_dim)
        
        # Input stem: lift image channels to channels[0] at full resolution.
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, channels[0], 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(channels[0], channels[0], 3, padding=1),
        )
        
        # Encoder: one ModuleList of blocks per stage; a downsampler between
        # consecutive stages (so there are num_stages - 1 downsamplers).
        self.encoder_blocks = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        for i in range(self.num_stages):
            stage = nn.ModuleList()
            for _ in range(blocks_per_stage[i]):
                stage.append(LiquidDiffusionBlock(
                    channels[i], t_dim, expand_ratio, kernel_size, dropout))
            self.encoder_blocks.append(stage)
            if i < self.num_stages - 1:
                self.downsamplers.append(DownSample(channels[i], channels[i + 1]))
        
        # Bottleneck: two blocks at the deepest resolution.
        self.bottleneck = nn.ModuleList([
            LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout),
            LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout),
        ])
        
        # Decoder: stages are stored deepest-first (decoder_blocks[0] runs at
        # channels[-1]); upsamplers/skip_fusions share that traversal order,
        # which forward() relies on via up_idx.
        self.decoder_blocks = nn.ModuleList()
        self.upsamplers = nn.ModuleList()
        self.skip_fusions = nn.ModuleList()
        for i in range(self.num_stages - 1, -1, -1):
            if i < self.num_stages - 1:
                self.upsamplers.append(UpSample(channels[i + 1], channels[i]))
                self.skip_fusions.append(SkipFusion(channels[i], t_dim))
            stage = nn.ModuleList()
            for _ in range(blocks_per_stage[i]):
                stage.append(LiquidDiffusionBlock(
                    channels[i], t_dim, expand_ratio, kernel_size, dropout))
            self.decoder_blocks.append(stage)
        
        # Output head (final conv initialized to zero for stable start).
        # Largest group count <= 32 dividing channels[0], same rule as AdaLN.
        head_groups = min(32, channels[0])
        while channels[0] % head_groups != 0:
            head_groups -= 1
        self.head = nn.Sequential(
            nn.GroupNorm(head_groups, channels[0]),
            nn.SiLU(),
            nn.Conv2d(channels[0], in_channels, 3, padding=1),
        )
        nn.init.zeros_(self.head[-1].weight)
        nn.init.zeros_(self.head[-1].bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, C, H, W] noisy image
            t: [B] timestep values in [0, 1]
        Returns:
            [B, C, H, W] predicted velocity
        """
        t_emb = self.time_embed(t)
        h = self.stem(x)
        
        # Encoder: record each stage's output (pre-downsample) as a skip.
        skips = []
        for i in range(self.num_stages):
            for block in self.encoder_blocks[i]:
                h = block(h, t_emb)
            skips.append(h)
            if i < self.num_stages - 1:
                h = self.downsamplers[i](h)
        
        # Bottleneck
        for block in self.bottleneck:
            h = block(h, t_emb)
        
        # Decoder: dec_i walks deepest-first; stage_idx maps back to the
        # encoder stage whose skip matches the current resolution.
        # NOTE(review): the deepest skip (skips[-1]) is never fused — the
        # bottleneck output feeds the deepest decoder stage directly; this
        # looks intentional (same resolution) but confirm.
        up_idx = 0
        for dec_i in range(self.num_stages):
            stage_idx = self.num_stages - 1 - dec_i
            if dec_i > 0:
                h = self.upsamplers[up_idx](h)
                h = self.skip_fusions[up_idx](h, skips[stage_idx], t_emb)
                up_idx += 1
            for block in self.decoder_blocks[dec_i]:
                h = block(h, t_emb)
        
        return self.head(h)

    def count_params(self):
        """Return (total, trainable) parameter counts."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return total, trainable


# =============================================================================
# 8. MODEL CONFIGS
# =============================================================================

def liquid_diffusion_tiny(**kwargs):
    """Tiny preset: ~23M params, 256px, fits ~6GB VRAM."""
    preset = dict(
        channels=[64, 128, 256],
        blocks_per_stage=[2, 2, 4],
        t_dim=256,
        expand_ratio=2.0,
        kernel_size=5,
    )
    return LiquidDiffusionUNet(**preset, **kwargs)

def liquid_diffusion_small(**kwargs):
    """Small preset: ~69M params, 256px, fits ~10GB VRAM."""
    preset = dict(
        channels=[96, 192, 384],
        blocks_per_stage=[2, 3, 6],
        t_dim=384,
        expand_ratio=2.0,
        kernel_size=5,
    )
    return LiquidDiffusionUNet(**preset, **kwargs)

def liquid_diffusion_base(**kwargs):
    """Base preset: ~154M params, 512px, fits ~16GB VRAM."""
    preset = dict(
        channels=[128, 256, 512],
        blocks_per_stage=[2, 4, 8],
        t_dim=512,
        expand_ratio=2.0,
        kernel_size=5,
    )
    return LiquidDiffusionUNet(**preset, **kwargs)

def liquid_diffusion_large(**kwargs):
    """Large preset: ~120M params, 512px, needs ~24GB VRAM."""
    preset = dict(
        channels=[128, 256, 512, 768],
        blocks_per_stage=[2, 4, 8, 4],
        t_dim=512,
        expand_ratio=2.0,
        kernel_size=5,
    )
    return LiquidDiffusionUNet(**preset, **kwargs)