ST-EEGFormer: Safetensors Weights
Pre-converted safetensors weights for the ST-EEGFormer EEG foundation model, ready for use with steegformer-rs (pure-Rust inference on Burn 0.20) or any framework that supports safetensors.
Weights are converted from the official PyTorch .pth checkpoints published at LiuyinYang1101/STEEGFormer.
ST-EEGFormer won 1st Place in the NeurIPS 2025 EEG Foundation Challenge and was accepted at ICLR 2026.
Model Files
Encoder Only (for inference / embedding extraction)
| File | Variant | Params | Size | Layers | Heads | embed_dim |
|---|---|---|---|---|---|---|
| ST-EEGFormer_small_encoder.safetensors | Small | 25.6 M | 102 MB | 8 | 8 | 512 |
| ST-EEGFormer_base_encoder.safetensors | Base | 85.6 M | 342 MB | 12 | 12 | 768 |
| ST-EEGFormer_large_encoder.safetensors | Large | 303.0 M | 1,212 MB | 24 | 16 | 1024 |
| ST-EEGFormer_largeV2_encoder.safetensors | Large V2 | 303.1 M | 1,212 MB | 24 | 16 | 1024 |
Full MAE (encoder + decoder, for reconstruction / fine-tuning)
| File | Variant | Params | Size | Decoder dim | Decoder depth |
|---|---|---|---|---|---|
| ST-EEGFormer_small_mae.safetensors | Small | 33.1 M | 132 MB | 384 | 4 |
| ST-EEGFormer_base_mae.safetensors | Base | 111.5 M | 446 MB | 512 | 8 |
| ST-EEGFormer_large_mae.safetensors | Large | 329.1 M | 1,316 MB | 512 | 8 |
| ST-EEGFormer_largeV2_mae.safetensors | Large V2 | 329.3 M | 1,317 MB | 512 | 8 |
Config
| File | Description |
|---|---|
| config.json | Model hyperparameters for all variants |
Large V2 has undergone further pre-training on the HBN dataset for the NeurIPS 2025 EEG Foundation Challenge.
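The file sizes in the tables above are consistent with every parameter being stored as a 32-bit float (the conversion snippet later in this document casts with `.float()`), which is 4 bytes per parameter. The following is a small sanity check under that assumption; the helper name `fp32_size_mb` is made up for illustration, and the few kilobytes of safetensors header are ignored.

```python
def fp32_size_mb(params_millions: float) -> int:
    """Approximate on-disk size in MB (10^6 bytes) for float32 weights."""
    bytes_per_param = 4  # float32
    return round(params_millions * bytes_per_param)


# Parameter counts (in millions) copied from the encoder table above.
encoders = {"small": 25.6, "base": 85.6, "large": 303.0, "largeV2": 303.1}
for name, params in encoders.items():
    print(f"{name}: ~{fp32_size_mb(params)} MB")
```

The same arithmetic reproduces the Full MAE sizes, for example 329.3 M parameters gives roughly 1,317 MB.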
Quick Start: Rust
# Install
cargo add steegformer-rs
# Download weights
huggingface-cli download eugenehp/ST-EEGFormer \
ST-EEGFormer_small_encoder.safetensors \
config.json \
--local-dir weights/
# Run inference
cargo run --release --bin infer -- \
--config weights/config.json \
--weights weights/ST-EEGFormer_small_encoder.safetensors
Library API
use steegformer_rs::{STEEGFormerEncoder, ModelConfig, data};
use std::path::Path;
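// Load model. `B` is the Burn backend type and `device` its device;
// both are assumed to be defined by the caller.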
let cfg = ModelConfig::small();
let (encoder, _ms) = STEEGFormerEncoder::<B>::load_from_config(
cfg,
Path::new("ST-EEGFormer_small_encoder.safetensors"),
device,
)?;
// Build input: 4 channels × 6 seconds @ 128 Hz (768 samples per channel)
let channels = &["Fz", "C3", "C4", "Pz"];
let signal = vec![0.0f32; channels.len() * 768];
let batch = data::build_batch_named::<B>(signal, channels, 768, &device);
// Extract embeddings
let result = encoder.run_batch(&batch)?;
println!("Embedding shape: {:?}", result.shape); // [512]
Quick Start: Python
from safetensors.torch import load_file
# Load encoder weights
state_dict = load_file("ST-EEGFormer_small_encoder.safetensors")
# Build model and load
from models_mae_eeg import mae_vit_small_patch16
model = mae_vit_small_patch16()
model.load_state_dict(state_dict, strict=False)
model.eval()
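Because the snippet above loads an encoder-only file into the full MAE model with `strict=False`, any mismatched keys are skipped without an error. The sketch below shows the kind of check that catches this, using key names only. The key lists are a small illustrative subset taken from the Weight Key Format section, and `diff_keys` is a hypothetical helper, not part of any library.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, mirroring what
    torch's load_state_dict reports when strict=False."""
    model, ckpt = set(model_keys), set(checkpoint_keys)
    missing = sorted(model - ckpt)      # in the model, absent from the file
    unexpected = sorted(ckpt - model)   # in the file, unknown to the model
    return missing, unexpected


# Illustrative subset of keys, not the full state dict.
mae_model_keys = ["cls_token", "norm.weight", "mask_token", "decoder_pred.weight"]
encoder_file_keys = ["cls_token", "norm.weight"]

missing, unexpected = diff_keys(mae_model_keys, encoder_file_keys)
print("missing:", missing)        # decoder-only parameters
print("unexpected:", unexpected)  # empty if the file is a clean subset
```

With PyTorch itself, the return value of `model.load_state_dict(state_dict, strict=False)` exposes the same information as `missing_keys` and `unexpected_keys`. For an encoder-only file, only decoder parameters should appear as missing.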
Architecture
EEG signal (B, C, T): up to 142 channels, 128 Hz, ≤ 6 s
        │
        ▼
┌──────────────────────────────────────┐
│ PatchEmbedEEG                        │
│   Unfold → 16-sample patches         │
│   Linear(16, embed_dim)              │
│   → (B, num_patches × C, D)          │
└──────────────────────────────────────┘
        │
  + Sinusoidal Temporal PE (fixed)
  + Learned Channel Embedding (nn.Embedding(145, D))
        │
        ▼
┌──────────────────────────────────────┐
│ [CLS] token prepend                  │
└──────────────────────────────────────┘
        │
        ▼
┌──────────────────────────────────────┐
│ N × Transformer Encoder Block        │
│   Pre-norm: LN → MHSA → residual     │
│             LN → FFN → residual      │
│   (qkv_bias=True, GELU activation)   │
└──────────────────────────────────────┘
        │
        ▼
┌──────────────────────────────────────┐
│ LayerNorm → CLS token                │
│ → (B, embed_dim) embedding           │
└──────────────────────────────────────┘
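To make the patching step concrete, here is a minimal pure-Python sketch of how a (C, T) signal turns into 16-sample patches and how many tokens the transformer then sees. It assumes non-overlapping patches (stride equal to the patch size) and that trailing samples which do not fill a patch are dropped; neither detail is stated above, and the function names are illustrative only.

```python
def patchify(signal, patch_size=16):
    """Split each channel into non-overlapping patches.

    signal: list of C channels, each a list of T samples.
    Returns a list of C lists, each holding T // patch_size patches.
    """
    patches = []
    for channel in signal:
        n = len(channel) // patch_size  # assumption: trailing samples are dropped
        patches.append([channel[i * patch_size:(i + 1) * patch_size]
                        for i in range(n)])
    return patches


def num_tokens(n_channels, n_samples, patch_size=16):
    """Sequence length seen by the transformer, including [CLS]."""
    return n_channels * (n_samples // patch_size) + 1


# 4 channels x 6 s at 128 Hz, as in the Rust example above.
signal = [[0.0] * 768 for _ in range(4)]
patches = patchify(signal)
print(len(patches), len(patches[0]), len(patches[0][0]))  # 4 48 16
print(num_tokens(4, 768))   # 193
print(num_tokens(22, 768))  # 1057
```

Since the sequence length grows linearly with the channel count, self-attention cost grows faster than linearly, which is worth keeping in mind when reading the channel-scaling benchmarks.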
MAE Pre-training (decoder, included in *_mae.safetensors)
Encoder output (25% of tokens)
        │
        ▼
Linear(embed_dim → decoder_dim)
+ Insert mask tokens at masked positions
+ Decoder temporal/channel PE
        │
        ▼
M × Decoder Transformer Blocks
        │
        ▼
Linear(decoder_dim → patch_size)
→ Reconstructed EEG patches
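The "25% of tokens" figure implies that about 75% of the patch tokens are hidden from the encoder during pre-training. A minimal sketch of MAE-style random masking follows. It assumes masking is uniform over all patch tokens and that the [CLS] token is excluded from masking; the exact sampling scheme used for ST-EEGFormer is not described here, and `random_mask` is a hypothetical helper.

```python
import random


def random_mask(num_tokens, keep_ratio=0.25, seed=None):
    """Pick which token indices stay visible to the encoder.

    Returns (kept, masked): two sorted lists that partition range(num_tokens).
    """
    rng = random.Random(seed)
    indices = list(range(num_tokens))
    rng.shuffle(indices)
    n_keep = int(num_tokens * keep_ratio)
    kept = sorted(indices[:n_keep])
    masked = sorted(indices[n_keep:])
    return kept, masked


# 4 channels x 48 patches = 192 patch tokens ([CLS] is not masked here).
kept, masked = random_mask(192, keep_ratio=0.25, seed=0)
print(len(kept), len(masked))  # 48 144
```

The decoder then receives the encoded visible tokens plus a shared mask token at each index in `masked`, as in the diagram above.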
Numerical Parity (Rust vs Python)
The Rust implementation was compared stage by stage against the official PyTorch implementation:
| Stage | RMSE | Pearson r |
|---|---|---|
| Patch embedding | 0.000000 | 1.000000 |
| Channel embedding | 0.000000 | 1.000000 |
| Temporal encoding | 0.000000 | 1.000000 |
| After positional encoding | 0.000000 | 1.000000 |
| After transformer block 0 | 0.000004 | 1.000000 |
| Full encoder (8 blocks) | 0.000001 | 1.000000 |
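For reference, the two metrics in the table can be computed as follows. This is a plain-Python sketch over flattened activations; the toy values stand in for real Rust and PyTorch outputs, and the actual comparison script may differ.

```python
import math


def rmse(a, b):
    """Root-mean-square error between two equal-length sequences."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))


def pearson_r(a, b):
    """Pearson correlation coefficient between two equal-length sequences."""
    mean_a = sum(a) / len(a)
    mean_b = sum(b) / len(b)
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


# Toy data standing in for flattened Rust and PyTorch activations.
reference = [0.10, -0.25, 0.40, 0.05]
candidate = [0.10, -0.25, 0.40, 0.05]
print(rmse(reference, candidate))
print(pearson_r(reference, candidate))
```

RMSE captures absolute deviation, while Pearson r only captures linear agreement, so reporting both guards against outputs that are correlated but scaled or offset.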
Benchmarks
Platform: Apple M4 Pro, 64 GB RAM, macOS (arm64)
Inference Latency: ST-EEGFormer-Small (22 ch × 768 samples)
| Backend | Mean | Min |
|---|---|---|
| Rust CPU (NdArray + Accelerate) | 608.4 ms | 601.4 ms |
| Python CPU (PyTorch 2.6) | 78.1 ms | 77.2 ms |
| Rust GPU (Burn wgpu + Metal) | 38.1 ms | 7.9 ms |
| Python MPS (PyTorch + Metal) | 19.2 ms | 19.0 ms |
Channel Scaling (T=768)
| Channels | Rust CPU | Python CPU | Rust GPU | Python MPS |
|---|---|---|---|---|
| 4 | 75.5 ms | 21.8 ms | 11.5 ms | 4.0 ms |
| 22 | 596.0 ms | 77.9 ms | 32.7 ms | 19.3 ms |
| 64 | 3853.2 ms | 301.9 ms | 119.4 ms | 90.1 ms |
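The measurement procedure behind these numbers is not described here. As a rough guide for reproducing mean and min latencies, the sketch below times a callable with a few untimed warm-up runs; the warm-up count, run count, and placeholder workload are all assumptions.

```python
import time


def benchmark(fn, warmup=3, runs=10):
    """Time fn() and return mean and min latency in milliseconds."""
    for _ in range(warmup):
        fn()  # warm-up runs are not timed
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return {"mean_ms": sum(samples) / len(samples), "min_ms": min(samples)}


# Placeholder workload; swap in a model forward pass.
stats = benchmark(lambda: sum(i * i for i in range(10_000)))
print(f"mean {stats['mean_ms']:.3f} ms, min {stats['min_ms']:.3f} ms")
```

On GPU backends the timed callable should also wait for the device to finish, since kernel launches are typically asynchronous and would otherwise be under-measured.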
Weight Key Format
Encoder keys
patch_embed.proj.weight [embed_dim, 16]
patch_embed.proj.bias [embed_dim]
cls_token [1, 1, embed_dim]
enc_channel_emd.channel_transformation.weight [145, embed_dim]
enc_temporal_emd.pe [1, 512, embed_dim]
blocks.{i}.norm1.weight [embed_dim]
blocks.{i}.norm1.bias [embed_dim]
blocks.{i}.attn.qkv.weight [3*embed_dim, embed_dim]
blocks.{i}.attn.qkv.bias [3*embed_dim]
blocks.{i}.attn.proj.weight [embed_dim, embed_dim]
blocks.{i}.attn.proj.bias [embed_dim]
blocks.{i}.norm2.weight [embed_dim]
blocks.{i}.norm2.bias [embed_dim]
blocks.{i}.mlp.fc1.weight [4*embed_dim, embed_dim]
blocks.{i}.mlp.fc1.bias [4*embed_dim]
blocks.{i}.mlp.fc2.weight [embed_dim, 4*embed_dim]
blocks.{i}.mlp.fc2.bias [embed_dim]
norm.weight [embed_dim]
norm.bias [embed_dim]
Decoder keys (MAE only)
decoder_embed.weight [dec_dim, embed_dim]
decoder_embed.bias [dec_dim]
mask_token [1, 1, dec_dim]
dec_channel_emd.channel_transformation.weight [145, dec_dim]
dec_temporal_emd.pe [1, 512, dec_dim]
decoder_blocks.{i}.* (same structure as encoder)
decoder_norm.weight [dec_dim]
decoder_norm.bias [dec_dim]
decoder_pred.weight [16, dec_dim]
decoder_pred.bias [16]
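The key names and shapes listed above can be checked without loading any tensor data, because a safetensors file starts with an 8-byte little-endian header length followed by a JSON header. The sketch below reads that header using only the standard library. To stay self-contained it first writes a tiny stand-in file with a single `cls_token` tensor; the helper names are illustrative, and in practice the path would point at one of the files from this repository.

```python
import json
import os
import struct
import tempfile


def read_safetensors_header(path):
    """Return {tensor_name: (dtype, shape)} without loading tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    header.pop("__metadata__", None)  # optional free-form metadata entry
    return {name: (info["dtype"], info["shape"]) for name, info in header.items()}


def write_toy_file(path):
    """Write a tiny stand-in file with one float32 tensor of shape [1, 1, 4]."""
    data = struct.pack("<4f", 0.0, 0.0, 0.0, 0.0)
    header = {"cls_token": {"dtype": "F32", "shape": [1, 1, 4],
                            "data_offsets": [0, len(data)]}}
    blob = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(blob)) + blob + data)


with tempfile.TemporaryDirectory() as tmp:
    toy_path = os.path.join(tmp, "toy.safetensors")
    write_toy_file(toy_path)
    shapes = read_safetensors_header(toy_path)

for name, (dtype, shape) in shapes.items():
    print(name, dtype, shape)
```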
Conversion
These weights were converted from the official .pth files:
import torch
from safetensors.torch import save_file
ckpt = torch.load("checkpoint.pth", map_location="cpu", weights_only=False)
state_dict = ckpt["model"]
# Encoder only
encoder = {k: v.float().contiguous() for k, v in state_dict.items()
if any(k.startswith(p) for p in
["patch_embed.", "cls_token", "enc_", "blocks.", "norm."])}
save_file(encoder, "encoder.safetensors")
Or use the included conversion script:
python scripts/convert_to_safetensors.py --all
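The prefix filter in the conversion snippet relies on `startswith` matching only at the beginning of a key, so `decoder_norm.` and `decoder_blocks.` are not picked up by the `norm.` and `blocks.` prefixes. The sketch below demonstrates this on key names from the Weight Key Format section; `is_encoder_key` is a hypothetical helper and the key list is a subset chosen for illustration.

```python
ENCODER_PREFIXES = ("patch_embed.", "cls_token", "enc_", "blocks.", "norm.")


def is_encoder_key(key):
    """True if the key belongs to the encoder-only export."""
    return key.startswith(ENCODER_PREFIXES)  # str.startswith accepts a tuple


keys = [
    "patch_embed.proj.weight",
    "cls_token",
    "enc_channel_emd.channel_transformation.weight",
    "blocks.0.attn.qkv.weight",
    "norm.weight",
    "mask_token",
    "decoder_blocks.0.attn.qkv.weight",
    "decoder_norm.weight",
    "decoder_pred.bias",
]
encoder_keys = [k for k in keys if is_encoder_key(k)]
decoder_keys = [k for k in keys if not is_encoder_key(k)]
print(len(encoder_keys), len(decoder_keys))  # 5 4
```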
Citation
@inproceedings{yang2026_steegformer,
title={Are {EEG} Foundation Models Worth It? Comparative Evaluation
with Traditional Decoders in Diverse {BCI} Tasks},
author={Liuyin Yang and Qiang Sun and Ang Li and Marc M. Van Hulle},
booktitle={The Fourteenth International Conference on Learning Representations},
year={2026},
url={https://openreview.net/forum?id=5Xwm8e6vbh}
}
License
MIT, same as the original ST-EEGFormer release.
Links
| Resource | Link |
|---|---|
| Rust crate | github.com/eugenehp/steegformer-rs |
| Original code | github.com/LiuyinYang1101/STEEGFormer |
| Original weights | GitHub Releases |
| Paper | OpenReview (ICLR 2026) |
| Burn framework | burn.dev |