File size: 10,432 Bytes
385f222 | 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"gpuType":"T4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{},"source":["# LiquidGen: Liquid Neural Network Image Generator\n","\n","Attention-free diffusion model with CfC Liquid Neural Network dynamics.\n","\n","**Colab-optimized:**\n","- Gradient checkpointing (saves ~50% VRAM)\n","- Auto batch size (detects GPU, picks safe batch)\n","- Latent pre-caching (no VAE during training)\n","- Open SDXL VAE (no login)\n","\n","| Preset | Images | Size | Type |\n","|--------|--------|------|------|\n","| `cartoon` | ~2.5K | 181MB | Cartoon/anime |\n","| `flowers` | ~8K | 331MB | Flowers |\n","| `art_painting` | ~6K | 511MB | Art paintings |\n","| `wikiart` | ~105K | 1.6GB | WikiArt (use max_images!) |\n","\n","| Model | Params | 256px T4 batch | 512px T4 batch |\n","|-------|--------|---------------|----------------|\n","| small | 55M | 32 | 16 |\n","| base | 140M | 32 | 8 |\n","| large | 279M | 16 | 4 |"]},{"cell_type":"markdown","metadata":{},"source":["## 1. Install"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install -q torch torchvision diffusers datasets accelerate huggingface_hub"]},{"cell_type":"markdown","metadata":{},"source":["## 2. 
Config"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["MODEL_SIZE = \"small\" # small/base/large\n","IMAGE_SIZE = 256 # 256 or 512\n","DATASET_PRESET = \"cartoon\" # see table above\n","MAX_IMAGES = 0 # 0=all, >0 to limit\n","BATCH_SIZE = 0 # 0 = AUTO (recommended!)\n","GRAD_ACCUM = 1\n","LEARNING_RATE = 1e-4\n","NUM_EPOCHS = 100\n","WARMUP_STEPS = 500\n","SAMPLE_EVERY = 500\n","SAMPLE_STEPS = 50\n","CFG_SCALE = 2.0\n","OUTPUT_DIR = \"/content/liquidgen\"\n","SAVE_EVERY = 2000\n","LOG_EVERY = 25\n","VAE_ID = \"madebyollin/sdxl-vae-fp16-fix\"\n","VAE_SCALE = 0.13025\n","LATENT_CH = 4\n","\n","import torch\n","if torch.cuda.is_available():\n"," g = torch.cuda.get_device_name(0)\n"," m = torch.cuda.get_device_properties(0).total_memory/1024**3 # attribute is total_memory, not total_mem\n"," print(f\"GPU: {g} ({m:.1f}GB)\")\n","else:\n"," print(\"No GPU! Runtime > Change runtime type > GPU\")"]},{"cell_type":"markdown","metadata":{},"source":["## 3. Download Code"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!wget -q -O model.py https://huggingface.co/asdf98/LiquidGen/resolve/main/model.py\n","!wget -q -O train.py https://huggingface.co/asdf98/LiquidGen/resolve/main/train.py\n","from model import LiquidGen, liquidgen_small, liquidgen_base, liquidgen_large\n","from train import (TrainConfig, DATASET_PRESETS, get_model_config,\n"," precache_latents, CachedLatentDataset, FlowMatchingScheduler,\n"," EMAModel, cosine_schedule, auto_batch_size)\n","print(\"Loaded!\")\n","for n,f in [(\"Small\",liquidgen_small),(\"Base\",liquidgen_base),(\"Large\",liquidgen_large)]:\n"," m=f(); print(f\" {n}: {m.count_params()/1e6:.0f}M\"); del m"]},{"cell_type":"markdown","metadata":{},"source":["## 4. 
Pre-Cache Latents"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os, time\n","os.makedirs(f\"{OUTPUT_DIR}/samples\", exist_ok=True)\n","os.makedirs(f\"{OUTPUT_DIR}/checkpoints\", exist_ok=True)\n","\n","preset = DATASET_PRESETS[DATASET_PRESET]\n","NUM_CLASSES = preset[\"num_classes\"]\n","config = TrainConfig(\n"," model_size=MODEL_SIZE, num_classes=NUM_CLASSES,\n"," dataset_preset=DATASET_PRESET, image_size=IMAGE_SIZE,\n"," max_images=MAX_IMAGES, batch_size=BATCH_SIZE,\n"," gradient_accumulation_steps=GRAD_ACCUM,\n"," learning_rate=LEARNING_RATE, num_epochs=NUM_EPOCHS,\n"," warmup_steps=WARMUP_STEPS, output_dir=OUTPUT_DIR,\n"," save_every_n_steps=SAVE_EVERY, sample_every_n_steps=SAMPLE_EVERY,\n"," log_every_n_steps=LOG_EVERY, num_sample_steps=SAMPLE_STEPS,\n"," cfg_scale=CFG_SCALE, vae_id=VAE_ID,\n"," vae_scaling_factor=VAE_SCALE, latent_channels=LATENT_CH,\n"," gradient_checkpointing=True,\n",")\n","cache_path = precache_latents(config)"]},{"cell_type":"markdown","metadata":{},"source":["## 5. 
Train"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torch.nn.functional as F\n","from torch.utils.data import DataLoader\n","from torch.amp import autocast, GradScaler\n","import math\n","\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","# GPU VRAM in GB (CUDA device attribute is total_memory, not total_mem); assume 8GB when no GPU\n","gpu_mem = torch.cuda.get_device_properties(0).total_memory/1024**3 if torch.cuda.is_available() else 8\n","\n","# Auto batch size\n","if config.batch_size <= 0:\n"," config.batch_size = auto_batch_size(MODEL_SIZE, IMAGE_SIZE, gpu_mem)\n"," print(f\"Auto batch size: {config.batch_size}\")\n","BS = config.batch_size\n","\n","train_ds = CachedLatentDataset(cache_path)\n","train_dl = DataLoader(train_ds, batch_size=BS, shuffle=True,\n"," num_workers=2, pin_memory=True, drop_last=True)\n","\n","mcfg = get_model_config(MODEL_SIZE, NUM_CLASSES)\n","mcfg[\"in_channels\"] = LATENT_CH\n","model = LiquidGen(**mcfg).to(device)\n","model.enable_gradient_checkpointing()\n","print(f\"LiquidGen-{MODEL_SIZE}: {model.count_params()/1e6:.1f}M (grad_ckpt=ON)\")\n","\n","opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)\n","total_steps = len(train_dl) * NUM_EPOCHS // GRAD_ACCUM\n","sched = cosine_schedule(opt, WARMUP_STEPS, total_steps)\n","ema = EMAModel(model, 0.9999)\n","scaler = GradScaler(\"cuda\")\n","fm = FlowMatchingScheduler()\n","lat_size = IMAGE_SIZE // 8\n","print(f\"Steps: {total_steps}, Batch: {BS}, Latent: [{BS},{LATENT_CH},{lat_size},{lat_size}]\")\n","if torch.cuda.is_available():\n"," print(f\"VRAM: {torch.cuda.memory_allocated()/1024**3:.1f}/{gpu_mem:.1f} GB\")\n","\n","gs=0; la=0; log_losses=[]; vae=None\n","print(\"\\nTraining!\\n\")\n","t0 = time.time()\n","\n","for epoch in range(NUM_EPOCHS):\n"," model.train(); et=time.time()\n"," for bi,(lats,lbls) in enumerate(train_dl):\n"," lats=lats.to(device)\n"," lbls=lbls.to(device) if NUM_CLASSES>0 else None\n"," t=fm.sample_timesteps(lats.shape[0],device)\n"," 
noise=torch.randn_like(lats)\n"," xt=fm.add_noise(lats,noise,t)\n"," vtgt=fm.get_velocity_target(lats,noise)\n"," with autocast(\"cuda\"):\n"," loss=F.mse_loss(model(xt,t,lbls),vtgt)/GRAD_ACCUM\n"," scaler.scale(loss).backward(); la+=loss.item()\n"," if (bi+1)%GRAD_ACCUM==0:\n"," scaler.unscale_(opt)\n"," gn=torch.nn.utils.clip_grad_norm_(model.parameters(),2.0)\n"," scaler.step(opt); scaler.update(); opt.zero_grad(); sched.step()\n"," ema.update(model); gs+=1\n"," if gs%LOG_EVERY==0:\n"," al=la/LOG_EVERY\n"," vram=torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0\n"," print(f\"step={gs:>5d} | ep={epoch} | loss={al:.4f} | gn={gn:.2f} | \"\n"," f\"lr={opt.param_groups[0]['lr']:.2e} | vram={vram:.1f}G\")\n"," log_losses.append(al); la=0\n"," if math.isnan(al): raise RuntimeError(\"Diverged!\") # bare break only exited the batch loop; training continued on a diverged model\n"," if gs%SAMPLE_EVERY==0:\n"," if vae is None:\n"," from diffusers import AutoencoderKL\n"," vae=AutoencoderKL.from_pretrained(VAE_ID,torch_dtype=torch.float16).to(device).eval()\n"," for p in vae.parameters(): p.requires_grad_(False)\n"," ema.apply(model); model.eval()\n"," sl=torch.randint(0,max(1,NUM_CLASSES),(4,),device=device) if NUM_CLASSES>0 else None\n"," samp=fm.sample(model,(4,LATENT_CH,lat_size,lat_size),device,SAMPLE_STEPS,sl,CFG_SCALE)\n"," with torch.no_grad():\n"," imgs=((vae.decode(samp.half()/VAE_SCALE).sample+1)/2).clamp(0,1).float()\n"," from torchvision.utils import save_image\n"," save_image(imgs,f\"{OUTPUT_DIR}/samples/step_{gs:07d}.png\",nrow=2)\n"," print(f\" Saved samples\"); ema.restore(model); model.train()\n"," if gs%SAVE_EVERY==0:\n"," torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":gs,\"cfg\":mcfg},\n"," f\"{OUTPUT_DIR}/checkpoints/step_{gs:07d}.pt\")\n"," print(f\"Epoch {epoch} | {time.time()-et:.0f}s\")\n","\n","torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":gs,\"cfg\":mcfg},\n"," f\"{OUTPUT_DIR}/checkpoints/final.pt\")\n","print(f\"\\nDone! 
{gs} steps, {(time.time()-t0)/60:.1f}min\")"]},{"cell_type":"markdown","metadata":{},"source":["## 6. Loss"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import matplotlib.pyplot as plt\n","if log_losses:\n"," plt.figure(figsize=(10,4)); plt.plot(log_losses)\n"," plt.xlabel(f\"Steps (x{LOG_EVERY})\"); plt.ylabel(\"Loss\")\n"," plt.title(\"Training Loss\"); plt.grid(True,alpha=0.3)\n"," plt.savefig(f\"{OUTPUT_DIR}/loss.png\",dpi=150); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 7. Generate"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["if vae is None:\n"," from diffusers import AutoencoderKL\n"," vae=AutoencoderKL.from_pretrained(VAE_ID,torch_dtype=torch.float16).to(device).eval()\n"," for p in vae.parameters(): p.requires_grad_(False)\n","ema.apply(model); model.eval()\n","ls=IMAGE_SIZE//8\n","s=fm.sample(model,(8,LATENT_CH,ls,ls),device,50)\n","with torch.no_grad():\n"," i=((vae.decode(s.half()/VAE_SCALE).sample+1)/2).clamp(0,1).float()\n","from torchvision.utils import save_image\n","save_image(i,f\"{OUTPUT_DIR}/generated.png\",nrow=4)\n","ema.restore(model); print(\"Saved!\")"]},{"cell_type":"markdown","metadata":{},"source":["## 8. Display"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["from IPython.display import display\n","from PIL import Image\n","import glob\n","for f in sorted(glob.glob(f\"{OUTPUT_DIR}/samples/*.png\"))[-3:]:\n"," print(os.path.basename(f)); display(Image.open(f))\n","for f in sorted(glob.glob(f\"{OUTPUT_DIR}/gen*.png\")):\n"," print(os.path.basename(f)); display(Image.open(f))"]}]} |