Brain-JEPA (safetensors)

Pretrained weights for Brain-JEPA (NeurIPS 2024 Spotlight), converted to the safetensors format for use with brainjepa-rs.

Model description

Brain-JEPA is a brain dynamics foundation model that maps parcellated fMRI time series (450 ROIs x T time points) to latent representations using a Vision Transformer with:

  • Brain gradient positioning for spatial (ROI) embeddings
  • Temporal patch embedding via 1D convolution along time
  • JEPA architecture (Joint Embedding Predictive Architecture)

The encoder is a 12-layer ViT-Base (768-dim, 12 heads, ~86M params) pretrained on UK Biobank resting-state fMRI for 300 epochs.
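As a sanity check on the shapes above, the token count follows directly from the patching scheme (a minimal sketch; the 160-time-point input length is the default used elsewhere in this card):

```python
# Each ROI's time series is split into non-overlapping temporal patches
# of length 16; every (ROI, temporal patch) pair becomes one token.
n_rois = 450
n_timepoints = 160
patch_size = 16          # temporal patch length
embed_dim = 768          # ViT-Base hidden size

n_temporal_patches = n_timepoints // patch_size   # 10
n_tokens = n_rois * n_temporal_patches            # 4500

print(n_tokens, embed_dim)  # encoder output: [4500, 768]
```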

Files

File                       Description                                           Shape info
brainjepa.safetensors      All weights (encoder + predictor + target_encoder)    384 tensors, ~709 MB
gradient_mapping_450.csv   Brain gradient coordinates for positional embeddings  450 rows x 30 columns

Weight key structure

Keys are prefixed by component (encoder., predictor., target_encoder.):

encoder.patch_embed.proj.weight          [768, 1, 1, 16]
encoder.blocks.{i}.norm1.weight          [768]
encoder.blocks.{i}.attn.qkv.weight       [2304, 768]
encoder.blocks.{i}.attn.proj.weight      [768, 768]
encoder.blocks.{i}.mlp.fc1.weight        [3072, 768]
encoder.blocks.{i}.mlp.fc2.weight        [768, 3072]
encoder.norm.weight                      [768]
...

For inference, use the target_encoder.* keys (the EMA-smoothed weights from pretraining).
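Since the file stores three components under one namespace, a quick way to audit the layout is to group keys by their first prefix segment. A sketch on a small stand-in list so it runs without the weights file; with the real file, the keys would come from safetensors (e.g. safe_open(path, framework="pt").keys()):

```python
from collections import Counter

# Stand-in key list mimicking the structure described above.
keys = [
    "encoder.patch_embed.proj.weight",
    "encoder.blocks.0.attn.qkv.weight",
    "predictor.blocks.0.attn.qkv.weight",
    "target_encoder.norm.weight",
]

# Count tensors per top-level component (encoder / predictor / target_encoder).
by_component = Counter(k.split(".", 1)[0] for k in keys)
print(by_component)  # Counter({'encoder': 2, 'predictor': 1, 'target_encoder': 1})
```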

Usage with brainjepa-rs (Rust)

# Install
git clone https://github.com/eugenehp/brainjepa-rs
cd brainjepa-rs

# Download weights from this repo
# Place brainjepa.safetensors and gradient_mapping_450.csv in data/

# Run inference (CPU)
cargo run --release --bin infer -- \
    --weights data/brainjepa.safetensors \
    --gradient data/gradient_mapping_450.csv \
    --input data/fmri_sample.safetensors

# Run inference (GPU, Metal/Vulkan)
cargo run --release --no-default-features --features wgpu --bin infer -- \
    --weights data/brainjepa.safetensors \
    --gradient data/gradient_mapping_450.csv \
    --input data/fmri_sample.safetensors

Rust library

use brainjepa_rs::{BrainJepaEncoder, ModelConfig, DataConfig};

let (encoder, _) = BrainJepaEncoder::<B>::from_weights(
    "data/brainjepa.safetensors",
    "data/gradient_mapping_450.csv",
    &ModelConfig::default(),
    &DataConfig::default(),
    &device,
)?;
let result = encoder.encode_safetensors("data/fmri.safetensors")?;
// result.embeddings: [4500, 768] float32

Usage with original Python code

These weights were converted from the original PyTorch checkpoint. To use with the original code:

import torch
from safetensors.torch import load_file

tensors = load_file("brainjepa.safetensors")
# Keep only the EMA target-encoder weights and strip the prefix:
state_dict = {
    k.removeprefix("target_encoder."): v
    for k, v in tensors.items()
    if k.startswith("target_encoder.")
}
# `model` is an instance of the original Brain-JEPA encoder module
model.load_state_dict(state_dict)

Conversion

Weights were converted from the original PyTorch checkpoint using:

python scripts/convert_weights.py \
    --input jepa-ep300.pth.tar \
    --output brainjepa.safetensors

The conversion script strips the "module." prefix from DDP-wrapped state dicts, casts all tensors to float32, and saves them in safetensors format.
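The key-cleaning step is small enough to sketch. The helper name here is hypothetical and the example operates on a plain dict so the prefix logic is visible; the real script additionally reads the .pth.tar checkpoint with torch.load, casts each tensor to float32, and writes via safetensors:

```python
def strip_ddp_prefix(state_dict):
    """Remove the 'module.' prefix that DistributedDataParallel adds."""
    return {k.removeprefix("module."): v for k, v in state_dict.items()}

# Keys without the prefix pass through unchanged.
cleaned = strip_ddp_prefix({"module.encoder.norm.weight": 0,
                            "encoder.pos_embed": 1})
print(sorted(cleaned))  # ['encoder.norm.weight', 'encoder.pos_embed']
```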

Benchmark

Tested on Mac Mini M4 Pro (14 cores, 64 GB). Input: [1, 1, 450, 160] (single sample, ViT-Base 86M params). Best-of-3 encode time.

Backend                            Encode time  vs PyTorch CPU
Rust - NdArray + Rayon (CPU)       28,778 ms    0.06x
Rust - NdArray + Accelerate (CPU)  21,092 ms    0.08x
Python - PyTorch (CPU)             1,782 ms     1.0x
Python - PyTorch MPS (GPU)         581 ms       3.1x
Rust - wgpu f32 / Metal (GPU)      83 ms        21.5x
Rust - wgpu f16 / Metal (GPU)      85 ms        21.0x

The Rust wgpu GPU backends are ~7x faster than PyTorch MPS and ~21x faster than PyTorch CPU.


Architecture details

Parameter        Value
Model            ViT-Base
Embedding dim    768
Encoder depth    12 layers
Predictor depth  6 layers
Attention heads  12
Head dim         64
MLP ratio        4x (hidden=3072)
Patch size       16 (temporal)
Input size       450 ROIs x 160 time points
Output           4500 patches x 768 dims
Normalization    LayerNorm (eps=1e-6)
Activation       GELU
Pretraining      300 epochs on UK Biobank
Loss             Smooth L1 (JEPA representation matching)
Optimizer        AdamW (lr=1e-3, warmup=40 epochs, cosine decay)
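The ~86M parameter figure is dominated by the twelve transformer blocks; a back-of-the-envelope count from the dimensions in the table (ignoring patch and positional embeddings, which contribute the remainder):

```python
d, mlp, depth = 768, 3072, 12

# Per transformer block: QKV projection, attention output projection,
# two MLP layers, and two LayerNorms (weights + biases throughout).
qkv   = d * 3 * d + 3 * d
proj  = d * d + d
fc1   = d * mlp + mlp
fc2   = mlp * d + d
norms = 2 * (2 * d)

per_block = qkv + proj + fc1 + fc2 + norms
total = depth * per_block
print(f"{total / 1e6:.1f}M")  # 85.1M in the blocks alone, ~86M with embeddings
```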

Source

Original paper and code:

Zijian Dong, Ruilin Li, Yilei Wu, et al. Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking. NeurIPS 2024 (Spotlight). arXiv:2409.19407
