| """ |
| LiquidDiffusion — Complete Test Suite |
| Tests model construction, forward/backward, training stability, and sampling. |
| Run: python test_model.py |
| """ |
| import sys |
| import math |
| import torch |
| import torch.nn.functional as F |
|
|
| |
sys.path.insert(0, '.')  # make the local `liquid_diffusion` package importable when run from the repo root
|
|
| from liquid_diffusion.model import ( |
| LiquidDiffusionUNet, liquid_diffusion_tiny, |
| liquid_diffusion_small, liquid_diffusion_base |
| ) |
|
|
# Suite banner.
_RULE = "=" * 70
print(_RULE)
print("LiquidDiffusion: Novel Attention-Free Image Generation")
print("Based on Liquid Neural Networks (CfC) + Rectified Flow")
print(_RULE)

# Global pass/fail flag; individual tests flip it to False on failure.
all_passed = True
|
|
| |
print("\n--- Test 1: Model Construction & Parameter Count ---")
# Build each size variant, report parameter counts, then free the model.
for name, factory in [
    ("tiny", liquid_diffusion_tiny),
    ("small", liquid_diffusion_small),
    ("base", liquid_diffusion_base),
]:
    m = factory()
    total, trainable = m.count_params()
    # Fix: `trainable` was previously computed but never reported.
    print(f" {name:8s}: {total:>12,} params ({total/1e6:.1f}M, {trainable:,} trainable)")
    del m  # release before constructing the next (larger) variant
|
|
| |
print("\n--- Test 2: Forward Pass (multiple resolutions) ---")
model = liquid_diffusion_tiny()
for res in [32, 64, 128]:
    x = torch.randn(2, 3, res, res)
    t = torch.rand(2)
    # Fix: shape-only check — no need to build an autograd graph here.
    with torch.no_grad():
        out = model(x, t)
    # The U-Net must preserve the input's (B, C, H, W) shape.
    ok = out.shape == x.shape
    print(f" {res}x{res}: {'OK' if ok else 'FAIL'} shape={out.shape}")
    if not ok:
        all_passed = False
|
|
| |
print("\n--- Test 3: Backward Pass (gradient flow) ---")
model = liquid_diffusion_tiny()
x = torch.randn(2, 3, 64, 64)
t = torch.rand(2)
out = model(x, t)
loss = out.mean()
loss.backward()
num_params_with_grad = sum(1 for p in model.parameters() if p.grad is not None)
nan_grads = sum(
    1 for p in model.parameters()
    if p.grad is not None and torch.isnan(p.grad).any()
)
print(f" Params with gradients: {num_params_with_grad}")
print(f" NaN gradients: {nan_grads}")
if nan_grads > 0:
    all_passed = False
# Fix: the test previously passed even if NO gradient flowed at all;
# a fully-disconnected graph must count as a failure.
if num_params_with_grad == 0:
    all_passed = False
|
|
| |
print("\n--- Test 4: Training Stability (20 steps, random data) ---")
model = liquid_diffusion_tiny()
model.train()  # hoisted out of the loop: the mode never changes between steps
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
losses = []
for step in range(20):
    # Rectified-flow objective: linearly interpolate between two samples
    # and regress the constant velocity x1 - x0.
    x0 = torch.randn(4, 3, 64, 64)
    x1 = torch.randn_like(x0)
    t_val = torch.rand(4)
    t_b = t_val[:, None, None, None]  # broadcast t over (C, H, W)
    x_t = (1 - t_b) * x0 + t_b * x1
    v_target = x1 - x0
    v_pred = model(x_t, t_val)
    loss = F.mse_loss(v_pred, v_target)
    optimizer.zero_grad()
    loss.backward()
    gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    losses.append(loss.item())
    if step % 5 == 0:
        print(f" Step {step:3d}: loss={loss.item():.4f}, grad_norm={gn.item():.4f}")

# math.isfinite covers both NaN and Inf in a single check.
stable = all(math.isfinite(v) for v in losses)
not_exploding = max(losses) < 100
print(f" Stable (no NaN/Inf): {'OK' if stable else 'FAIL'}")
print(f" Not exploding: {'OK' if not_exploding else 'FAIL'} (max={max(losses):.4f})")
if not stable or not not_exploding:
    all_passed = False
|
|
| |
print("\n--- Test 5: Sampling (10 Euler steps) ---")
model.eval()
# Fix: step count (10) and step size (0.1) were hard-coded independently
# and had to agree by coincidence; derive dt from num_steps instead.
num_steps = 10
dt = 1.0 / num_steps
with torch.no_grad():
    # Integrate the learned velocity field from t=1 (noise) down to t=0.
    z = torch.randn(2, 3, 64, 64)
    for i in range(num_steps, 0, -1):
        t_s = torch.full((2,), i / num_steps)
        v = model(z, t_s)
        z = z - v * dt  # explicit Euler step backward in t
    z = z.clamp(-1, 1)
    print(f" Shape: {z.shape}, range: [{z.min():.3f}, {z.max():.3f}]")
|
|
| |
print("\n--- Test 6: Timestep Sensitivity ---")
model.eval()
x = torch.randn(1, 3, 64, 64)
outputs = []
# Hoisted no_grad around the whole sweep (was re-entered per t value).
with torch.no_grad():
    for t_val in [0.01, 0.25, 0.5, 0.75, 0.99]:
        out = model(x, torch.tensor([t_val]))
        outputs.append(out)
        print(f" t={t_val:.2f}: mean={out.mean():.6f}, std={out.std():.6f}")
# Fix: the test printed stats but never checked sensitivity — a model that
# ignores t entirely would still "pass". Require at least one distinct output.
t_sensitive = not all(torch.allclose(outputs[0], o) for o in outputs[1:])
print(f" Output varies with t: {'OK' if t_sensitive else 'FAIL'}")
if not t_sensitive:
    all_passed = False
|
|
| |
print("\n--- Test 7: Architecture Properties ---")
m = liquid_diffusion_tiny()
# Count CfC blocks stage by stage: encoder + bottleneck + decoder.
n_encoder = sum(len(stage) for stage in m.encoder_blocks)
n_decoder = sum(len(stage) for stage in m.decoder_blocks)
total_blocks = n_encoder + len(m.bottleneck) + n_decoder
print(" Attention layers: 0")
print(" Sequential loops: 0")
print(f" CfC blocks: {total_blocks}")
print(" Training objective: Rectified Flow (MSE velocity)")
|
|
| |
print("\n--- Test 8: VRAM Estimates (fp16 training) ---")
configs = [
    ("tiny 256px bs4", liquid_diffusion_tiny, 256, 4),
    ("small 256px bs4", liquid_diffusion_small, 256, 4),
    ("base 256px bs2", liquid_diffusion_base, 256, 2),
    ("tiny 512px bs2", liquid_diffusion_tiny, 512, 2),
]
for name, factory, res, bs in configs:
    m = factory()
    tp = sum(p.numel() for p in m.parameters())
    # Rough heuristic: per-param bytes (presumably fp16 weights + fp32 grads
    # + 2x fp32 Adam moments) plus a 0.3-scaled activation term.
    param_gb = (tp * 2 + tp * 4 + tp * 8) / 1e9
    act_gb = bs * 3 * res * res * 4 * len(m.channels) * max(m.channels) / 1e9 * 0.3
    est = param_gb + act_gb
    print(f" {name:20s}: {tp/1e6:.1f}M params, ~{est:.1f}GB VRAM")
    del m  # free before the next, possibly larger, model
|
|
# Final verdict banner.
rule = "=" * 70
verdict = "PASSED" if all_passed else "SOME FAILURES"
print("\n" + rule)
print(f"ALL TESTS {verdict}")
print(rule)
|
|