""" LiquidDiffusion — Complete Test Suite Tests model construction, forward/backward, training stability, and sampling. Run: python test_model.py """ import sys import math import torch import torch.nn.functional as F # Add parent directory to path sys.path.insert(0, '.') from liquid_diffusion.model import ( LiquidDiffusionUNet, liquid_diffusion_tiny, liquid_diffusion_small, liquid_diffusion_base ) print("=" * 70) print("LiquidDiffusion: Novel Attention-Free Image Generation") print("Based on Liquid Neural Networks (CfC) + Rectified Flow") print("=" * 70) all_passed = True # Test 1: Model construction print("\n--- Test 1: Model Construction & Parameter Count ---") for name, factory in [("tiny", liquid_diffusion_tiny), ("small", liquid_diffusion_small), ("base", liquid_diffusion_base)]: m = factory() total, trainable = m.count_params() print(f" {name:8s}: {total:>12,} params ({total/1e6:.1f}M)") del m # Test 2: Forward pass print("\n--- Test 2: Forward Pass (multiple resolutions) ---") model = liquid_diffusion_tiny() for res in [32, 64, 128]: x = torch.randn(2, 3, res, res) t = torch.rand(2) out = model(x, t) ok = out.shape == x.shape print(f" {res}x{res}: {'OK' if ok else 'FAIL'} shape={out.shape}") if not ok: all_passed = False # Test 3: Backward pass print("\n--- Test 3: Backward Pass (gradient flow) ---") model = liquid_diffusion_tiny() x = torch.randn(2, 3, 64, 64) t = torch.rand(2) out = model(x, t) loss = out.mean() loss.backward() num_params_with_grad = sum(1 for p in model.parameters() if p.grad is not None) nan_grads = sum(1 for p in model.parameters() if p.grad is not None and torch.isnan(p.grad).any()) print(f" Params with gradients: {num_params_with_grad}") print(f" NaN gradients: {nan_grads}") if nan_grads > 0: all_passed = False # Test 4: Training stability (20 steps) print("\n--- Test 4: Training Stability (20 steps, random data) ---") model = liquid_diffusion_tiny() optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01) losses = [] for step in range(20): model.train() x0 = torch.randn(4, 3, 64, 64) x1 = torch.randn_like(x0) t_val = torch.rand(4) x_t = (1 - t_val[:, None, None, None]) * x0 + t_val[:, None, None, None] * x1 v_target = x1 - x0 v_pred = model(x_t, t_val) loss = F.mse_loss(v_pred, v_target) optimizer.zero_grad() loss.backward() gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() losses.append(loss.item()) if step % 5 == 0: print(f" Step {step:3d}: loss={loss.item():.4f}, grad_norm={gn.item():.4f}") stable = all(not math.isnan(l) and not math.isinf(l) for l in losses) not_exploding = max(losses) < 100 print(f" Stable (no NaN/Inf): {'OK' if stable else 'FAIL'}") print(f" Not exploding: {'OK' if not_exploding else 'FAIL'} (max={max(losses):.4f})") if not stable or not not_exploding: all_passed = False # Test 5: Sampling print("\n--- Test 5: Sampling (10 Euler steps) ---") model.eval() with torch.no_grad(): z = torch.randn(2, 3, 64, 64) for i in range(10, 0, -1): t_s = torch.full((2,), i / 10.0) v = model(z, t_s) z = z - v * 0.1 z = z.clamp(-1, 1) print(f" Shape: {z.shape}, range: [{z.min():.3f}, {z.max():.3f}]") # Test 6: Timestep sensitivity print("\n--- Test 6: Timestep Sensitivity ---") model.eval() x = torch.randn(1, 3, 64, 64) for t_val in [0.01, 0.25, 0.5, 0.75, 0.99]: with torch.no_grad(): out = model(x, torch.tensor([t_val])) print(f" t={t_val:.2f}: mean={out.mean():.6f}, std={out.std():.6f}") # Test 7: Architecture properties print("\n--- Test 7: Architecture Properties ---") m = liquid_diffusion_tiny() total_blocks = (sum(len(s) for s in m.encoder_blocks) + len(m.bottleneck) + sum(len(s) for s in m.decoder_blocks)) print(f" Attention layers: 0") print(f" Sequential loops: 0") print(f" CfC blocks: {total_blocks}") print(f" Training objective: Rectified Flow (MSE velocity)") # Test 8: VRAM estimates print("\n--- Test 8: VRAM Estimates (fp16 training) ---") for name, factory, res, bs in [ ("tiny 256px bs4", liquid_diffusion_tiny, 256, 4), ("small 256px bs4", liquid_diffusion_small, 256, 4), ("base 256px bs2", liquid_diffusion_base, 256, 2), ("tiny 512px bs2", liquid_diffusion_tiny, 512, 2), ]: m = factory() tp = sum(p.numel() for p in m.parameters()) est = (tp * 2 + tp * 4 + tp * 8) / 1e9 + bs * 3 * res * res * 4 * len(m.channels) * max(m.channels) / 1e9 * 0.3 print(f" {name:20s}: {tp/1e6:.1f}M params, ~{est:.1f}GB VRAM") del m print("\n" + "=" * 70) print(f"ALL TESTS {'PASSED' if all_passed else 'SOME FAILURES'}") print("=" * 70)