Conditional Latent Diffusion Model for Retinal Future-State Synthesis

Trained model weights for predicting two-year follow-up retinal fundus images from baseline photographs and clinical metadata.

Model Description

This model adapts Stable Diffusion 1.5 for longitudinal retinal image prediction. It consists of two components:

  1. Fine-tuned VAE (vae_best.pt, 320 MB): SD 1.5 VAE encoder/decoder fine-tuned on retinal fundus images with L1 + SSIM + LPIPS + KL loss. Achieves SSIM 0.954 on reconstruction.

  2. Conditional U-Net (diffusion_best.pt, 13 GB): 860M-parameter denoising U-Net with 15-channel input (4 noisy latent + 4 baseline latent + 7 clinical feature maps). Trained for 500 epochs with cosine LR schedule, EMA, and classifier-free guidance dropout.
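A minimal sketch of how the 15-channel input described above could be assembled. The function and tensor names are illustrative, not the released code; in particular, broadcasting each clinical scalar to a constant spatial map is an assumption about how the 7 clinical feature maps are formed:

```python
import torch

def build_unet_input(noisy_latent: torch.Tensor,
                     baseline_latent: torch.Tensor,
                     clinical: torch.Tensor) -> torch.Tensor:
    """Concatenate the noisy latent (B, 4, H, W), the baseline latent
    (B, 4, H, W), and 7 clinical scalars (B, 7) broadcast to constant
    spatial maps, yielding a 15-channel input (B, 15, H, W)."""
    b, _, h, w = noisy_latent.shape
    # Each clinical value becomes one constant feature map.
    clinical_maps = clinical.view(b, -1, 1, 1).expand(b, clinical.shape[1], h, w)
    return torch.cat([noisy_latent, baseline_latent, clinical_maps], dim=1)
```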

Performance

Metric | Value
------ | --------
SSIM   | 0.762
PSNR   | 17.26 dB
LPIPS  | 0.379
FID    | 107.28

Evaluated on 110 held-out test image pairs.

Qualitative Results


Each row shows a different test patient. Columns: baseline fundus, ground-truth follow-up, our prediction, Regression U-Net, and Pix2Pix. Our diffusion model generates sharper, more realistic retinal textures compared to deterministic baselines.

Training Dynamics


(a, b) Stage 1: VAE fine-tuning over 50 epochs reaching SSIM 0.954. (c-f) Stage 2: U-Net training over 500 epochs with cosine LR schedule and warmup.
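One common form of the cosine schedule with warmup named above is linear warmup to the base learning rate followed by cosine decay to zero. This is a sketch of that convention, not the exact schedule used in training:

```python
import math

def cosine_lr_with_warmup(step: int, total_steps: int,
                          warmup_steps: int, base_lr: float) -> float:
    """Linear warmup to base_lr over warmup_steps, then cosine decay
    to zero at total_steps."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
```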

Guidance Scale Sweep


SSIM peaks at guidance scale 7.5, while FID increases monotonically with stronger guidance, reflecting the fidelity-diversity tradeoff.
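The sweep varies the standard classifier-free guidance rule, which extrapolates from the unconditional noise prediction toward the conditional one. A minimal sketch (tensor names are illustrative):

```python
import torch

def cfg_predict(eps_uncond: torch.Tensor, eps_cond: torch.Tensor,
                guidance_scale: float) -> torch.Tensor:
    """Classifier-free guidance: scale = 1.0 recovers the purely
    conditional prediction; larger scales push further along the
    conditional direction."""
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```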

Usage

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel

# Load VAE
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae_state = torch.load("vae_best.pt", map_location="cpu")
if "model_state_dict" in vae_state:
    vae_state = vae_state["model_state_dict"]
# strict=False tolerates minor key-name differences between the
# checkpoint and the diffusers VAE
vae.load_state_dict(vae_state, strict=False)

# Load U-Net (requires modified conv_in for 15 input channels)
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
# ... modify conv_in and load checkpoint
# See full inference code at the GitHub repository
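The conv_in modification above can be sketched generically. The weight-copy scheme below (pretrained weights into the first four input channels, zeros for the new ones, so the widened layer initially ignores the extra channels) is a common convention and an assumption about this checkpoint; consult the repository for the exact procedure:

```python
import torch
import torch.nn as nn

def expand_conv_in(conv: nn.Conv2d, new_in_channels: int) -> nn.Conv2d:
    """Widen a conv layer's input channels, copying the pretrained
    weights into the first slots and zero-initializing the rest."""
    new_conv = nn.Conv2d(
        new_in_channels, conv.out_channels,
        kernel_size=conv.kernel_size, stride=conv.stride,
        padding=conv.padding, bias=conv.bias is not None,
    )
    with torch.no_grad():
        new_conv.weight.zero_()
        new_conv.weight[:, : conv.in_channels] = conv.weight
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv

# With a real checkpoint: unet.conv_in = expand_conv_in(unet.conv_in, 15)
```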

Links

Citation

@article{usama2026retinal,
  title={Conditional Latent Diffusion for Predictive Retinal Fundus Image Synthesis from Baseline Imaging and Clinical Metadata},
  author={Usama, Muhammad and Pazo, Emmanuel Eric and Li, Xiaorong and Liu, Juping},
  journal={Computers in Biology and Medicine (under review)},
  year={2026}
}

License

CC BY-NC 4.0. Non-commercial research use only.

