PokéPixels1-9M (CPU)

A minimal diffusion model trained from scratch on CPU.

This project explores the lower limits of diffusion models:
How small and simple can a diffusion model be while still producing recognizable images?


Here are some "Fakemons" generated by the model (64x64 resolution):


🧠 Overview

PokéPixels1-9M is a lightweight DDPM-based generative model trained from scratch on Pokémon images.

Despite its small size and CPU-only training, the model learns:

  • Color distributions
  • Basic shapes
  • Early-stage object structure

⚙️ Specifications

| Component | Value |
|---|---|
| Parameters | ~9M |
| Resolution | 64x64 |
| Training Device | CPU (Ryzen 5 5600G) |
| Training Time | ~5.5 hours |
| Dataset | pokemon-blip-captions |
| Architecture | Custom UNet |
| Precision | float32 |
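The ~9M parameter count can be checked with a generic helper; a minimal sketch (the commented lines assume this repo's `train.py` is importable, which is how the usage script below loads the model):

```python
import torch.nn as nn

def count_params(model: nn.Module) -> int:
    """Total number of trainable parameters in a module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# With this repo's model (assuming train.py is on the path):
# from train import StudentUNet, Config
# print(count_params(StudentUNet(Config())))  # should land near ~9M

# Quick sanity check on a toy module:
print(count_params(nn.Linear(64, 64)))  # 64*64 weights + 64 biases = 4160
```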

🧪 Features

  • Full DDPM implementation from scratch
  • Custom UNet with attention blocks
  • CPU-optimized training
  • Deterministic sampling (seed support)
  • Config-driven architecture
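The "from scratch" DDPM implementation rests on the closed-form forward (noising) process. A minimal sketch, assuming a linear beta schedule (the schedule values here are illustrative; the real ones come from this repo's `Config`):

```python
import torch

# Illustrative linear beta schedule; the actual values live in the project's Config
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)  # cumulative product of (1 - beta_t)

def q_sample(x0, t, noise):
    """Forward diffusion: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    ab = alpha_bar[t].view(-1, 1, 1, 1)  # broadcast per-sample timestep over CHW
    return ab.sqrt() * x0 + (1.0 - ab).sqrt() * noise

x0 = torch.randn(4, 3, 64, 64)    # a batch of clean images in [-1, 1]
t = torch.randint(0, T, (4,))     # a random timestep per image
noise = torch.randn_like(x0)
xt = q_sample(x0, t, noise)       # the noised input the UNet is trained to denoise
```

During training the UNet receives `(xt, t)` and is optimized (typically with an MSE loss) to predict `noise`.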

🖼️ Results

The model generates:

  • Coherent color palettes
  • Recognizable Pokémon-like silhouettes
  • Early-stage structure formation

Limitations:

  • Blurry outputs
  • Weak spatial consistency
  • No semantic understanding

Note: the initial idea was to distill a student U-Net from a teacher U-Net, but this was abandoned because the teacher was initialized with random weights, so the student would have had nothing meaningful to learn from.

🚀 Usage

Generate images


```python
import torch
from PIL import Image

# ===== CONFIG =====
CHECKPOINT = "model.pt"
N_IMAGES = 8
STEPS = 50
SEED = 42
OUT = "generated.png"

# ===== IMPORT MODEL =====
from train import StudentUNet, DDPMScheduler, Config

# ===== LOAD =====
torch.manual_seed(SEED)  # fixed seed -> deterministic sampling

ckpt = torch.load(CHECKPOINT, map_location="cpu")
cfg = ckpt.get("config", Config())

model = StudentUNet(cfg)
model.load_state_dict(ckpt["model_state"])
model.eval()

scheduler = DDPMScheduler(cfg.timesteps, cfg.beta_start, cfg.beta_end)

# ===== SAMPLING =====
@torch.no_grad()
def sample(model, scheduler, n, steps):
    # Start from pure Gaussian noise
    x = torch.randn(n, 3, cfg.image_size, cfg.image_size)

    # Use a strided subset of the T training timesteps for faster sampling
    step_size = scheduler.T // steps
    timesteps = list(range(0, scheduler.T, step_size))[::-1]

    for t_val in timesteps:
        t = torch.full((n,), t_val, dtype=torch.long)

        noise_pred = model(x, t)

        if t_val > 0:
            # Recompute effective alpha/beta for the strided step
            ab = scheduler.alpha_bar[t_val]
            prev_t = max(t_val - step_size, 0)
            ab_prev = scheduler.alpha_bar[prev_t]

            beta_t = 1.0 - (ab / ab_prev)
            alpha_t = 1.0 - beta_t

            # DDPM posterior mean, plus fresh noise scaled by beta_t
            mean = (1.0 / alpha_t.sqrt()) * (
                x - (beta_t / (1.0 - ab).sqrt()) * noise_pred
            )

            x = mean + beta_t.sqrt() * torch.randn_like(x)
        else:
            # Final step: directly predict the clean image x0
            x = scheduler.predict_x0(x, noise_pred, t)

    return x.clamp(-1, 1)

samples = sample(model, scheduler, N_IMAGES, STEPS)

# ===== SAVE =====
# Map from [-1, 1] to [0, 255] uint8, channels-last for PIL
samples = (samples + 1) / 2
samples = (samples * 255).byte().permute(0, 2, 3, 1).numpy()

grid = Image.new("RGB", (cfg.image_size * N_IMAGES, cfg.image_size))

for i, img in enumerate(samples):
    grid.paste(Image.fromarray(img), (i * cfg.image_size, 0))

grid.save(OUT)

print(f"✅ Saved to {OUT}")
```

Or, assuming a `generate.py` wrapper that exposes the same settings as command-line flags (the standalone script above reads them from its CONFIG constants instead):

```
python generate.py \
  --checkpoint model.pt \
  --n_images 8 \
  --steps 50 \
  --seed 42
```

📁 Output

Generated images are saved as a single horizontal grid at the path set by `OUT` in the script above:

generated.png

⚠️ Limitations

  • Unconditional model (no prompts)
  • Limited dataset diversity
  • Early training stage
  • No DDIM sampling (yet)
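For context on the last point: DDIM replaces the stochastic DDPM update with a deterministic one, which usually holds up better at low step counts. A minimal sketch of a single DDIM step (eta = 0), using an illustrative schedule rather than this model's actual config:

```python
import torch

# Illustrative linear beta schedule (the real one comes from the model's config)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)

def ddim_step(x_t, noise_pred, t, t_prev):
    """One deterministic DDIM step (eta = 0): no fresh noise is injected."""
    ab_t = alpha_bar[t]
    ab_prev = alpha_bar[t_prev] if t_prev >= 0 else torch.tensor(1.0)
    # Estimate the clean image from the current sample and the predicted noise
    x0 = (x_t - (1.0 - ab_t).sqrt() * noise_pred) / ab_t.sqrt()
    # Jump to the previous timestep deterministically
    return ab_prev.sqrt() * x0 + (1.0 - ab_prev).sqrt() * noise_pred
```

With a perfect noise prediction, a single step to `t_prev = -1` recovers `x0` exactly, which is why DDIM tolerates aggressive step skipping.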

🔬 Research Direction

This project demonstrates that:

Diffusion models can learn meaningful visual structure even at extremely small scales.

Future work:

  • Conditional generation (class-based)
  • Text-to-image (v2.0)
  • DDIM sampling
  • Larger model variants

💡 Motivation

Most diffusion research focuses on scaling up.

This project explores the opposite direction:

What is the minimum viable diffusion model?

📜 License

MIT

🙌 Acknowledgments

  • Hugging Face Datasets
  • PyTorch
  • The open-source AI community

⭐ If you like this project:

Give it a star and follow the evolution to v2.0 (conditional) 🚀
