flash-attn-4-sm120

Flash Attention 4 (CuTe DSL) for SM120 / SM121 consumer Blackwell GPUs.

This is a downstream distribution of Dao-AILab/flash-attention that bundles six open upstream PRs targeting consumer Blackwell hardware (RTX 5090, RTX PRO 6000, DGX Spark GB10, SM121a). Once these PRs merge upstream, prefer the upstream flash-attn package; this bundle exists so SM120 users can use the improvements today.

Why this exists

flash-attn-4's CuTe DSL kernels work great on Hopper (SM90) and datacenter Blackwell (SM100). But SM120 (consumer Blackwell) is genuinely different hardware:

  • No tcgen05 / TMEM (so FA4's primary speed path doesn't apply)
  • No WGMMA (so the SM90 epilogue path doesn't apply)
  • 99 KB shared memory capacity (vs 163 KB on SM80)
  • Has TMA, but only single-CTA flavor
  • Same SM80-era mma.sync.aligned.m16n8k16 for FP16/BF16 MMA

The PRs bundled here adapt FA4's kernels to these constraints: runtime-correct dispatch, SMEM-budget-aware tiling, paged KV that fits in 99 KB, a TMA load path with warp specialization, and fixes for a couple of crashes that previously blocked dispatch entirely.
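
For reference, SM120 / SM121 can be detected at runtime from the CUDA compute capability. A small check of ours (not something the kernel requires), assuming PyTorch with CUDA is installed:

import torch

# Consumer Blackwell reports compute capability 12.x, e.g. (12, 0) on RTX 5090
# and (12, 1) on DGX Spark GB10; datacenter Blackwell (SM100) reports 10.x.
major, minor = torch.cuda.get_device_capability()
print(f"compute capability {major}.{minor}, consumer Blackwell: {major == 12}")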

Bundled PRs

PR Title
#2336 SM120 split-KV (FlashDecoding) with FP32 partial outputs
#2348 SM120 kernel-level paged KV cache support
#2349 SM120 TMA forward kernel with warp specialization
#2389 SM80 / SM120 block-sparse forward attention support
#2439 FA4 dropout (Philox, per-element, all arches)
#2484 SM120 init-time runtime fix + GQA pack_gqa workaround

Setup

Hardware

  • NVIDIA SM120 / SM121 / SM121a (RTX 5090, RTX PRO 6000 Blackwell, DGX Spark GB10)
  • Should also work on SM80 / SM90 / SM100 since the bundle inherits from upstream flash-attn-4, but those paths are not the primary target

Software

  • CUDA Toolkit 12.8 or newer (FA4 baseline requirement)
  • PyTorch with CUDA support
  • nvidia-cutlass-dsl >= 4.4.1 (auto-installed by kernels)
  • einops, apache-tvm-ffi (auto-installed)

Installation via the kernels library (recommended)

# Shell
pip install -U kernels

# Python
from kernels import get_kernel

flash_attn_4 = get_kernel("SecondNatureComputing/flash-attn-4-sm120")

kernels will download this repository, resolve dependencies, and make the package importable without any manual build step.

Direct use (alternative)

If you prefer not to use the kernels library, you can clone the repo and import the package directly:

# Shell
git clone https://huggingface.co/SecondNatureComputing/flash-attn-4-sm120

# Python
import sys
import importlib

sys.path.insert(0, "flash-attn-4-sm120/build/torch-cuda")
flash_attn_4 = importlib.import_module("flash_attn_4_sm120")  # adjust if you cloned the repo under a different name

The kernels.get_kernel(...) path is recommended since it handles caching and dependency resolution automatically.

Usage

Basic β€” non-causal MHA

import torch
from kernels import get_kernel

flash_attn_4 = get_kernel("SecondNatureComputing/flash-attn-4-sm120")

B, S, H, D = 1, 1024, 16, 128
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)

out, _ = flash_attn_4.flash_attn_func(q, k, v, causal=False)
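
To sanity-check the output against a PyTorch reference (our own snippet, not part of the package API), note that scaled_dot_product_attention expects (batch, heads, seqlen, head_dim), so the head and sequence dims are transposed:

import torch.nn.functional as F

ref = F.scaled_dot_product_attention(
    q.transpose(1, 2).float(),
    k.transpose(1, 2).float(),
    v.transpose(1, 2).float(),
).transpose(1, 2)
print((out.float() - ref).abs().max())  # bf16-level error, on the order of 1e-2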

Causal GQA β€” Qwen / LLaMA family models

B, S, Hq, Hkv, D = 1, 2048, 16, 8, 128  # Qwen3-style GQA: Hq=16, Hkv=8

q = torch.randn(B, S, Hq,  D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, S, Hkv, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, S, Hkv, D, device="cuda", dtype=torch.bfloat16)

out, _ = flash_attn_4.flash_attn_func(q, k, v, causal=True)
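
For reference, the GQA head mapping (each group of Hq // Hkv consecutive query heads shares one KV head) can be reproduced by repeating KV heads before calling PyTorch's SDPA. This is a verification snippet of ours, not part of the API:

import torch.nn.functional as F

k_rep = k.repeat_interleave(Hq // Hkv, dim=2)  # expand 8 KV heads to 16
v_rep = v.repeat_interleave(Hq // Hkv, dim=2)
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2).float(),
    k_rep.transpose(1, 2).float(),
    v_rep.transpose(1, 2).float(),
    is_causal=True,
).transpose(1, 2)
print((out.float() - ref).abs().max())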

Variable-length (production batched serving)

# Pack a batch of sequences with different lengths into a single flat tensor
import itertools

Hq, Hkv, D = 16, 8, 128  # same GQA shape as above
seq_lens = [128, 256, 512]
total = sum(seq_lens)
cu_seqlens = torch.tensor([0] + list(itertools.accumulate(seq_lens)),
                          dtype=torch.int32, device="cuda")

q = torch.randn(total, Hq,  D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(total, Hkv, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(total, Hkv, D, device="cuda", dtype=torch.bfloat16)

out, _ = flash_attn_4.flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seq_lens),
    max_seqlen_k=max(seq_lens),
    causal=True,
)
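
The packed output has shape (total, Hq, D) and can be split back into per-sequence tensors with the same lengths:

# One tensor per request: shapes (128, Hq, D), (256, Hq, D), (512, Hq, D)
outs = out.split(seq_lens, dim=0)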

Paged KV (vLLM / SGLang serving pattern)

# k_paged / v_paged: KV cache stored in fixed-size pages (PR #2348)
# page_table: maps each sequence to the pages holding its KV cache
# seqused_k: number of valid KV tokens per sequence
out, _ = flash_attn_4.flash_attn_func(
    q, k_paged, v_paged,
    page_table=page_table,
    seqused_k=actual_seq_lens,
    max_seqlen_k=max_kv_len,
    causal=True,
)
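
A minimal sketch of what the inputs above might look like, assuming the upstream flash-attn paged-KV convention of (num_pages, page_size, Hkv, D) caches and a (batch, pages_per_seq) int32 page table. PR #2348 defines the exact contract, so treat these shapes as illustrative:

import torch

B, Hq, Hkv, D = 2, 16, 8, 128
page_size, max_kv_len = 256, 1024
pages_per_seq = max_kv_len // page_size

# Pages come from a global pool; each row of page_table lists one sequence's pages
k_paged = torch.zeros(B * pages_per_seq, page_size, Hkv, D, device="cuda", dtype=torch.bfloat16)
v_paged = torch.zeros_like(k_paged)
page_table = torch.arange(B * pages_per_seq, dtype=torch.int32, device="cuda").reshape(B, pages_per_seq)

# Decode step: one new query token per sequence, plus the true KV length of each sequence
q = torch.randn(B, 1, Hq, D, device="cuda", dtype=torch.bfloat16)
actual_seq_lens = torch.tensor([700, 1024], dtype=torch.int32, device="cuda")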

API

Two entry points exposed at the package root:

  • flash_attn_func(q, k, v, ...) - standard attention, fixed-length within a batch
  • flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, ...) - variable-length, packed sequences

Parameters supported beyond upstream main:

Parameter What it enables PR
page_table=... Paged KV cache #2348
num_splits=... Split-KV / FlashDecoding #2336
block_sparse_tensors=... Block-sparse attention #2389
dropout_p=..., dropout_seed=... Per-element dropout #2439
(automatic) TMA forward dispatch when viable #2349

Tensor layout: (batch, seqlen, num_heads, head_dim), last dim contiguous, 16-byte aligned.
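
A quick layout check you can run on your own tensors (a helper of ours, not part of the package):

def check_fa_layout(t):
    # head_dim must be the innermost, unit-stride dimension, and storage 16-byte aligned
    assert t.stride(-1) == 1, "head_dim must be contiguous"
    assert t.data_ptr() % 16 == 0, "tensor must be 16-byte aligned"

# e.g. for the q, k, v tensors from the examples above:
# for t in (q, k, v): check_fa_layout(t)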

Validation

End-to-end on SM121a (DGX Spark GB10), bf16 + fp16, causal + non-causal, dense + varlen:

Shape category Configurations
MHA (Hq = Hkv) D ∈ {64, 128}, S ∈ {128, 256, 512, 1024}
GQA Qwen3-style Hq=16, Hkv=8, D=128
GQA LLaMA3-style Hq=32, Hkv=8, D=128
MQA Hq=4, Hkv=1, D=128
Batched B = 2
  • Forward: 64 / 64 configurations pass - max diff ≤ 0.0156 vs PyTorch f32 reference
  • Backward: 40 / 40 configurations pass (dq, dk, dv all within 0.05 vs PyTorch f32 reference)
  • Standalone install: validated via kernels.get_kernel(...) from a clean Python venv with only kernels, torch, nvidia-cutlass-dsl, apache-tvm-ffi, einops, quack-kernels installed - no flash-attn dependency required.
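
A sketch of the kind of backward check behind the 0.05 tolerance, assuming gradients flow through flash_attn_func as in upstream flash-attn and reusing the flash_attn_4 handle from the setup above (our snippet, not the exact validation harness):

import torch
import torch.nn.functional as F

q = torch.randn(1, 512, 16, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out, _ = flash_attn_4.flash_attn_func(q, k, v, causal=True)
out.sum().backward()

# fp32 reference gradients via SDPA (KV heads repeated 16 // 8 = 2 times for GQA)
qr, kr, vr = (t.detach().float().requires_grad_(True) for t in (q, k, v))
ref = F.scaled_dot_product_attention(
    qr.transpose(1, 2),
    kr.repeat_interleave(2, dim=2).transpose(1, 2),
    vr.repeat_interleave(2, dim=2).transpose(1, 2),
    is_causal=True,
).transpose(1, 2)
ref.sum().backward()

# dq / dk / dv are expected to stay within ~0.05 of the reference, per the table above
print((q.grad.float() - qr.grad).abs().max(),
      (k.grad.float() - kr.grad).abs().max(),
      (v.grad.float() - vr.grad).abs().max())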

Performance

Patched HF FA4 vs vLLM's FA2 baseline on SM121a (DGX Spark), bf16, causal, Qwen3-style GQA Hq=16, Hkv=8, D=128, median of 30 iters after 5 warmups:

Shape (B, S, Hq, Hkv, D) HF FA4 (ms) vLLM FA2 (ms) FA4 / FA2
(1, 128, 16, 8, 128) 0.036 0.021 1.71x
(1, 512, 16, 8, 128) 0.053 0.049 1.07x
(1, 1024, 16, 8, 128) 0.106 0.102 1.04x
(1, 2048, 16, 8, 128) 0.289 0.278 1.04x
(1, 4096, 16, 8, 128) 0.976 0.886 1.10x
(2, 512, 16, 8, 128) 0.075 0.069 1.09x
(4, 256, 16, 8, 128) 0.059 0.049 1.19x
(8, 256, 16, 8, 128) 0.109 0.104 1.05x

At very short sequences (S = 128) FA4's dispatch overhead dominates (~70% slower than FA2). At realistic Qwen3 prefill lengths (S = 512 to 4096) FA4 is within 4 to 10 percent of FA2. This is consistent with the SM120 hardware: with no tcgen05 / TMEM, FA4's primary speed path doesn't apply, so it compiles down to roughly the same SM80-era mma.sync compute as FA2, plus a small dispatch overhead. Use this kernel for the FA4-only features (paged KV, score_mod, block-sparse attention, dropout); use FA2 if pure attention throughput is the only goal.
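
A sketch of the timing methodology (median of 30 iterations after 5 warmups, CUDA events), reusing q, k, v from the usage examples; this is not the exact benchmark script behind the table above:

import torch

def median_ms(fn, iters=30, warmup=5):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))  # milliseconds
    return sorted(times)[len(times) // 2]

print(median_ms(lambda: flash_attn_4.flash_attn_func(q, k, v, causal=True)))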

Known limitations

  • GQA dispatches through the non-packed path on SM120 (PR #2484 workaround). Functionally correct on every GQA / MQA shape we tested. Throughput is within roughly 10% of fmha_v2 on the GQA shapes measured. Tracked upstream.
  • head_dim > 128 is not supported on SM120: the 99 KB SMEM budget cannot hold the Q tile. This affects models like Qwen3.5-9B (D=256) and Qwen3-Coder-Next (D=256). vLLM's existing fa_utils.py gate already routes head_size > 128 to FA2 on Blackwell; this kernel maintains that boundary.
  • Split-KV is not supported on SM120 in this kernel variant. PR #2336 implements it, but the bundle's interface.py clamps num_splits to 1 on SM12x. Decode workloads therefore use a single split, which is consistent with how vLLM and SGLang configure SM120 today.
  • Dropout runs but spills registers at tile_m=128, tile_n=128 in the non-causal path; the bundle's interface.py falls back to tile_m=128, tile_n=64 (or tile_m=64, tile_n=64 for D > 64) when dropout_p > 0, which fixes the spill at a small throughput cost.
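
If you route between kernels in your own serving code, the head_dim boundary above reduces to a simple caller-side gate. This is a sketch of ours mirroring the vLLM-style check, not an API of this package:

def can_use_fa4_sm120(head_dim: int) -> bool:
    # SM120's 99 KB shared-memory budget cannot hold the Q tile for head_dim > 128,
    # so route D = 256 models to FA2 (or another kernel) instead.
    return head_dim <= 128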

Hardware support outside SM120

The bundle inherits from upstream flash-attn-4's SM80 / SM90 / SM100 dispatch paths. Those should work the same as upstream main; the bundled PRs target SM120 specifically. We do not test SM80 / SM90 / SM100; please open an issue if you find regressions.

License

BSD-3-Clause, inherited from Dao-AILab/flash-attention.

Credits

Issues

For bundle-specific issues (the dispatch logic, validation gaps, packaging), open an issue on this HF repo. For kernel-level issues that exist upstream, file against Dao-AILab/flash-attention directly.

See also

  • CONFLICTS_LOG.md - a detailed log of every conflict encountered while stacking the six PRs, with resolutions and per-PR backport guidance