# Conditional Latent Diffusion Model for Retinal Future-State Synthesis
Trained model weights for predicting two-year follow-up retinal fundus images from baseline photographs and clinical metadata.
## Model Description
This model adapts Stable Diffusion 1.5 for longitudinal retinal image prediction. It consists of two components:

- **Fine-tuned VAE** (`vae_best.pt`, 320 MB): the SD 1.5 VAE encoder/decoder fine-tuned on retinal fundus images with a combined L1 + SSIM + LPIPS + KL loss. Achieves SSIM 0.954 on reconstruction.
- **Conditional U-Net** (`diffusion_best.pt`, 13 GB): an 860M-parameter denoising U-Net with a 15-channel input (4 noisy-latent + 4 baseline-latent + 7 clinical feature-map channels). Trained for 500 epochs with a cosine LR schedule, EMA, and classifier-free guidance dropout.
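The 15-channel conditioning input can be assembled by concatenating along the channel axis. A minimal sketch, assuming 64×64 latents (the SD 1.5 latent size for 512×512 images) and clinical scalars broadcast to constant feature maps; the exact clinical features, their ordering, and their normalization are assumptions here, not taken from the repository:

```python
import numpy as np

# Hypothetical shapes: SD 1.5 latents are 4 x 64 x 64 for 512 x 512 images.
noisy_latent = np.random.randn(4, 64, 64).astype(np.float32)
baseline_latent = np.random.randn(4, 64, 64).astype(np.float32)

# Seven clinical scalars broadcast to constant-valued feature maps
# (placeholder values; the real features/normalization are not shown here).
clinical = np.array([0.62, 0.41, 0.0, 1.0, 0.33, 0.5, 0.1], dtype=np.float32)
clinical_maps = np.broadcast_to(clinical[:, None, None], (7, 64, 64))

# 4 + 4 + 7 = 15 channels, matching the U-Net's widened conv_in
unet_input = np.concatenate([noisy_latent, baseline_latent, clinical_maps], axis=0)
print(unet_input.shape)  # (15, 64, 64)
```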
## Performance
| Metric | Value |
|---|---|
| SSIM | 0.762 |
| PSNR | 17.26 dB |
| LPIPS | 0.379 |
| FID | 107.28 |
Evaluated on 110 held-out test image pairs.
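The PSNR figure in the table can be reproduced with plain NumPy (SSIM, LPIPS, and FID require dedicated libraries); a minimal sketch, assuming 8-bit images and the standard definition of PSNR in dB:

```python
import numpy as np

def psnr(pred: np.ndarray, target: np.ndarray, max_val: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB; higher is better."""
    mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(max_val ** 2 / mse)

# Two images differing by a constant offset of 10 gray levels: MSE = 100
a = np.full((256, 256, 3), 100, dtype=np.uint8)
b = np.full((256, 256, 3), 110, dtype=np.uint8)
print(round(psnr(a, b), 2))  # 28.13
```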
## Qualitative Results
Each row shows a different test patient. Columns: baseline fundus, ground-truth follow-up, our prediction, Regression U-Net, and Pix2Pix. Our diffusion model generates sharper, more realistic retinal textures compared to deterministic baselines.
## Training Dynamics
(a, b) Stage 1: VAE fine-tuning over 50 epochs reaching SSIM 0.954. (c-f) Stage 2: U-Net training over 500 epochs with cosine LR schedule and warmup.
## Guidance Scale Sweep
SSIM peaks at a guidance scale of 7.5, while FID rises (worsens) monotonically with stronger guidance, reflecting the fidelity-diversity tradeoff of classifier-free guidance.
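The sweep relies on classifier-free guidance: at each denoising step the U-Net is run with and without conditioning, and the two noise predictions are combined. A minimal sketch of the standard Ho & Salimans combination rule (the repository's implementation may differ in detail):

```python
import numpy as np

def cfg_combine(eps_uncond: np.ndarray, eps_cond: np.ndarray, scale: float) -> np.ndarray:
    # eps_hat = eps_uncond + s * (eps_cond - eps_uncond)
    # s = 1 recovers the conditional prediction; s > 1 extrapolates past it.
    return eps_uncond + scale * (eps_cond - eps_uncond)

# Toy predictions to show the extrapolation at the swept scale of 7.5
eps_u = np.zeros((4, 64, 64))
eps_c = np.ones((4, 64, 64))
out = cfg_combine(eps_u, eps_c, 7.5)
print(out.max())  # 7.5
```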
## Usage
```python
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel

# Load the fine-tuned VAE on top of the SD 1.5 weights
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae_state = torch.load("vae_best.pt", map_location="cpu")
if "model_state_dict" in vae_state:
    vae_state = vae_state["model_state_dict"]
vae.load_state_dict(vae_state, strict=False)

# Load the U-Net; conv_in must first be widened from 4 to 15 input
# channels to match the checkpoint, which was trained with 15.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
old = unet.conv_in
unet.conv_in = torch.nn.Conv2d(15, old.out_channels, kernel_size=old.kernel_size,
                               stride=old.stride, padding=old.padding)
unet.load_state_dict(torch.load("diffusion_best.pt", map_location="cpu"), strict=False)
# See the full inference code at the GitHub repository
```
## Links
- Code: [github.com/Usama1002/retinal-diffusion](https://github.com/Usama1002/retinal-diffusion)
- Dataset: [huggingface.co/datasets/usama10/retinal-dr-longitudinal](https://huggingface.co/datasets/usama10/retinal-dr-longitudinal)
## Citation
```bibtex
@article{usama2026retinal,
  title={Conditional Latent Diffusion for Predictive Retinal Fundus Image Synthesis from Baseline Imaging and Clinical Metadata},
  author={Usama, Muhammad and Pazo, Emmanuel Eric and Li, Xiaorong and Liu, Juping},
  journal={Computers in Biology and Medicine (under review)},
  year={2026}
}
```
## License
CC BY-NC 4.0. Non-commercial research use only.


