# flash-attn-4-sm120

Flash Attention 4 (CuTe DSL) for SM120 / SM121 consumer Blackwell GPUs.

This is a downstream distribution of Dao-AILab/flash-attention that bundles six open upstream PRs targeting consumer Blackwell hardware (RTX 5090, RTX PRO 6000, DGX Spark GB10, SM121a). Once these PRs merge upstream, prefer the upstream flash-attn package; this bundle exists so SM120 users can use the improvements today.
## Why this exists

flash-attn-4's CuTe DSL kernels work well on Hopper (SM90) and datacenter Blackwell (SM100). But SM120 (consumer Blackwell) is genuinely different hardware:

- No tcgen05 / TMEM (so FA4's primary speed path doesn't apply)
- No WGMMA (so the SM90 epilogue path doesn't apply)
- 99 KB shared memory capacity (vs 163 KB on SM80)
- Has TMA, but only the single-CTA flavor
- Same SM80-era `mma.sync.aligned.m16n8k16` for FP16/BF16 MMA

The PRs bundled here adapt FA4's kernels to these constraints: runtime-correct dispatch, SMEM-budget-aware tiling, paged KV that fits in 99 KB, TMA with warp specialization for the load path, and a couple of fixes for crashes that block dispatch entirely.
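The "runtime-correct dispatch" point boils down to mapping the device's compute capability to the kernel family it can actually run. The sketch below is illustrative only: the function name and path labels are hypothetical, not the bundle's actual `interface.py` logic.

```python
# Illustrative arch-based dispatch sketch; names are hypothetical and do not
# reflect the bundle's actual interface.py internals.

def pick_fa4_path(major: int, minor: int) -> str:
    """Map a CUDA compute capability to the FA4 kernel family it can run."""
    if major == 10:               # datacenter Blackwell: tcgen05 / TMEM path
        return "sm100_tcgen05"
    if major == 12:               # consumer Blackwell: no TMEM, no WGMMA
        return "sm120_mma_sync"   # SM80-era mma.sync compute, single-CTA TMA
    if major == 9:                # Hopper: WGMMA path
        return "sm90_wgmma"
    if major == 8:
        return "sm80_mma_sync"
    raise ValueError(f"unsupported compute capability {major}.{minor}")

# On an RTX 5090 or DGX Spark, torch.cuda.get_device_capability()
# returns (12, 0) or (12, 1):
print(pick_fa4_path(12, 1))  # sm120_mma_sync
```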
## Bundled PRs
| PR | Title |
|---|---|
| #2336 | SM120 split-KV (FlashDecoding) with FP32 partial outputs |
| #2348 | SM120 kernel-level paged KV cache support |
| #2349 | SM120 TMA forward kernel with warp specialization |
| #2389 | SM80 / SM120 block-sparse forward attention support |
| #2439 | FA4 dropout (Philox, per-element, all arches) |
| #2484 | SM120 init-time runtime fix + GQA pack_gqa workaround |
## Setup

### Hardware

- NVIDIA SM120 / SM121 / SM121a (RTX 5090, RTX PRO 6000 Blackwell, DGX Spark GB10)
- Should also work on SM80 / SM90 / SM100, since the bundle inherits from upstream flash-attn-4, but those paths are not the primary target

### Software

- CUDA Toolkit 12.8 or newer (FA4 baseline requirement)
- PyTorch with CUDA support
- `nvidia-cutlass-dsl >= 4.4.1` (auto-installed by `kernels`)
- `einops`, `apache-tvm-ffi` (auto-installed)
## Installation via the `kernels` library (recommended)

```shell
pip install -U kernels
```

```python
from kernels import get_kernel

flash_attn_4 = get_kernel("SecondNatureComputing/flash-attn-4-sm120")
```

`kernels` will download this repository, resolve dependencies, and make the package importable without any manual build step.
## Direct use (alternative)

If you prefer not to use the `kernels` library, you can clone the repo and import the package directly:

```shell
git clone https://huggingface.co/SecondNatureComputing/flash-attn-4-sm120
```

```python
import sys
import importlib

sys.path.insert(0, "flash-attn-4-sm120/build/torch-cuda")
flash_attn_4 = importlib.import_module("flash_attn_4_sm120")  # or whatever you alias the dir to
```

The `kernels.get_kernel(...)` path is recommended since it handles caching and dependency resolution automatically.
## Usage

### Basic: non-causal MHA

```python
import torch
from kernels import get_kernel

flash_attn_4 = get_kernel("SecondNatureComputing/flash-attn-4-sm120")

B, S, H, D = 1, 1024, 16, 128
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)

out, _ = flash_attn_4.flash_attn_func(q, k, v, causal=False)
```
### Causal GQA: Qwen / LLaMA family models

```python
B, S, Hq, Hkv, D = 1, 2048, 16, 8, 128  # Qwen3-style GQA: Hq=16, Hkv=8
q = torch.randn(B, S, Hq, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, S, Hkv, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, S, Hkv, D, device="cuda", dtype=torch.bfloat16)

out, _ = flash_attn_4.flash_attn_func(q, k, v, causal=True)
```
### Variable-length (production batched serving)

```python
import itertools

# Pack a batch of sequences with different lengths into a single flat tensor
seq_lens = [128, 256, 512]
total = sum(seq_lens)
cu_seqlens = torch.tensor([0] + list(itertools.accumulate(seq_lens)),
                          dtype=torch.int32, device="cuda")

Hq, Hkv, D = 16, 8, 128
q = torch.randn(total, Hq, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(total, Hkv, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(total, Hkv, D, device="cuda", dtype=torch.bfloat16)

out, _ = flash_attn_4.flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seq_lens),
    max_seqlen_k=max(seq_lens),
    causal=True,
)
```
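The `cu_seqlens` packing is plain prefix-sum bookkeeping: `cu_seqlens[i]` is where sequence `i` starts in the packed token dimension, and `cu_seqlens[-1]` is the total. The stdlib-only sketch below (helper names are ours, not part of the kernel API) shows the same arithmetic, including how to slice a packed output back into per-sequence chunks.

```python
from itertools import accumulate

def build_cu_seqlens(seq_lens):
    """Exclusive prefix sum: cu_seqlens[i] is where sequence i starts."""
    return [0] + list(accumulate(seq_lens))

def unpack(packed, cu_seqlens):
    """Slice a packed (total, ...) buffer back into one chunk per sequence."""
    return [packed[s:e] for s, e in zip(cu_seqlens, cu_seqlens[1:])]

seq_lens = [128, 256, 512]
cu = build_cu_seqlens(seq_lens)      # [0, 128, 384, 896]
flat = list(range(sum(seq_lens)))    # stand-in for the packed token dim
chunks = unpack(flat, cu)
assert [len(c) for c in chunks] == seq_lens
```

The same `unpack` indexing works on the `out` tensor returned by `flash_attn_varlen_func`, since it shares the packed first dimension with `q`.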
### Paged KV (vLLM / SGLang serving pattern)

```python
# q, k_paged, v_paged, page_table, actual_seq_lens, and max_kv_len are
# assumed to be set up by the serving engine (see PR #2348 for layouts)
out, _ = flash_attn_4.flash_attn_func(
    q, k_paged, v_paged,
    page_table=page_table,
    seqused_k=actual_seq_lens,
    max_seqlen_k=max_kv_len,
    causal=True,
)
```
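Conceptually, `page_table` maps each sequence's logical KV blocks to physical pages in a shared cache, so sequences can grow without contiguous allocation. A minimal stdlib sketch of that bookkeeping follows; the page size and the bump allocator are illustrative, not this kernel's internals.

```python
import math

def build_page_table(seq_lens, page_size=256):
    """Assign physical pages to each sequence's logical KV blocks.

    Returns (page_table, pages_used): page_table[i][j] is the physical page
    backing logical block j of sequence i. Pages are handed out with a toy
    bump allocator; real engines (vLLM, SGLang) reuse freed pages.
    """
    table, next_free = [], 0
    for n in seq_lens:
        n_pages = math.ceil(n / page_size)
        table.append(list(range(next_free, next_free + n_pages)))
        next_free += n_pages
    return table, next_free

table, used = build_page_table([100, 600, 256], page_size=256)
assert table == [[0], [1, 2, 3], [4]]  # 1, 3, and 1 pages respectively
assert used == 5
```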
## API

Two entry points are exposed at the package root:

- `flash_attn_func(q, k, v, ...)`: standard attention, fixed-length within a batch
- `flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, ...)`: variable-length

Parameters supported beyond upstream main:

| Parameter | What it enables | PR |
|---|---|---|
| `page_table=...` | Paged KV cache | #2348 |
| `num_splits=...` | Split-KV / FlashDecoding | #2336 |
| `block_sparse_tensors=...` | Block-sparse attention | #2389 |
| `dropout_p=...`, `dropout_seed=...` | Per-element dropout | #2439 |
| (automatic) | TMA forward dispatch when viable | #2349 |

Tensor layout: `(batch, seqlen, num_heads, head_dim)`, last dim contiguous, 16-byte aligned.
## Validation

End-to-end on SM121a (DGX Spark GB10), bf16 + fp16, causal + non-causal, dense + varlen:

| Shape category | Configurations |
|---|---|
| MHA (Hq = Hkv) | D ∈ {64, 128}, S ∈ {128, 256, 512, 1024} |
| GQA Qwen3-style | Hq=16, Hkv=8, D=128 |
| GQA LLaMA3-style | Hq=32, Hkv=8, D=128 |
| MQA | Hq=4, Hkv=1, D=128 |
| Batched | B = 2 |

- Forward: 64 / 64 configurations pass; max diff ≤ 0.0156 vs PyTorch f32 reference
- Backward: 40 / 40 configurations pass (dq, dk, dv all within 0.05 vs PyTorch f32 reference)
- Standalone install: validated via `kernels.get_kernel(...)` from a clean Python venv with only `kernels`, `torch`, `nvidia-cutlass-dsl`, `apache-tvm-ffi`, `einops`, `quack-kernels` installed; no `flash-attn` dependency required
## Performance

Patched HF FA4 vs vLLM's FA2 baseline on SM121a (DGX Spark), bf16, causal, Qwen3-style GQA (Hq=16, Hkv=8, D=128), median of 30 iterations after 5 warmups:
| Shape (B, S, Hq, Hkv, D) | HF FA4 (ms) | vLLM FA2 (ms) | FA4 / FA2 |
|---|---|---|---|
| (1, 128, 16, 8, 128) | 0.036 | 0.021 | 1.71x |
| (1, 512, 16, 8, 128) | 0.053 | 0.049 | 1.07x |
| (1, 1024, 16, 8, 128) | 0.106 | 0.102 | 1.04x |
| (1, 2048, 16, 8, 128) | 0.289 | 0.278 | 1.04x |
| (1, 4096, 16, 8, 128) | 0.976 | 0.886 | 1.10x |
| (2, 512, 16, 8, 128) | 0.075 | 0.069 | 1.09x |
| (4, 256, 16, 8, 128) | 0.059 | 0.049 | 1.19x |
| (8, 256, 16, 8, 128) | 0.109 | 0.104 | 1.05x |
At very short sequences (S = 128), FA4's dispatch overhead dominates (~70% slower than FA2). At realistic Qwen3 prefill lengths (S = 512 to 4096), FA4 is within 4 to 10 percent of FA2. This is consistent with the SM120 hardware: with no tcgen05 / TMEM, FA4's primary speed path doesn't apply, so it compiles down to roughly the same SM80-era `mma.sync` compute as FA2, plus a small dispatch overhead. Use this kernel for the FA4-only features (paged KV, score_mod, block-sparse, dropout); use FA2 if pure attention throughput is the only goal.
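The table's timing methodology (median of 30 iterations after 5 warmup iterations, which absorb JIT and autotune cost) reduces to a few lines; a stdlib sketch with an illustrative helper name:

```python
from statistics import median

def summarize_latency_ms(samples_ms, warmup=5):
    """Discard warmup iterations, report the median of the rest."""
    steady = samples_ms[warmup:]
    return median(steady)

# First iterations include one-time compile/autotune cost, so they are dropped:
samples = [5.0, 4.0, 3.0, 1.2, 1.1] + [0.106] * 30
assert summarize_latency_ms(samples) == 0.106
```

On GPU, the per-iteration samples themselves should come from `torch.cuda.Event` elapsed times (with synchronization), not wall-clock timers, since kernel launches are asynchronous.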
## Known limitations

- GQA dispatches through the non-packed path on SM120 (PR #2484 workaround). Functionally correct on every GQA / MQA shape we tested; throughput is within roughly 10% of fmha_v2 on the GQA shapes measured. Tracked upstream.
- `head_dim > 128` is not supported on SM120: the 99 KB SMEM budget cannot hold the Q tile. This affects models like Qwen3.5-9B (D=256) and Qwen3-Coder-Next (D=256). vLLM's existing `fa_utils.py` gate already routes `head_size > 128` to FA2 on Blackwell; this kernel maintains that boundary.
- Split-KV is not supported on SM120 in this kernel variant. PR #2336 implements it, but the bundle's `interface.py` clamps `num_splits` to 1 on SM12x. Decode workloads use a single split, which is consistent with how vLLM and SGLang configure SM120 today.
- Dropout runs but spills registers at `tile_m=128, tile_n=128` non-causal; the bundle's `interface.py` falls back to `tile_m=128, tile_n=64` (or `tile_m=64, tile_n=64` for `D > 64`) when `dropout_p > 0`, which fixes the spill at a small throughput cost.
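The `head_dim > 128` limit follows from back-of-envelope SMEM arithmetic. The sketch below uses assumed tile sizes (`tile_m=128`, `tile_n=64`) and a double-buffering factor of 2 for the K/V tiles; the kernel's real accounting differs in detail, but the conclusion is the same.

```python
def smem_bytes(tile_m, tile_n, head_dim, dtype_bytes=2, kv_stages=2):
    """Rough SMEM footprint: one Q tile plus double-buffered K and V tiles.

    Tile shapes and staging factor are illustrative assumptions, not the
    kernel's exact accounting.
    """
    q = tile_m * head_dim * dtype_bytes                     # Q tile
    kv = kv_stages * 2 * tile_n * head_dim * dtype_bytes    # K and V, staged
    return q + kv

budget = 99 * 1024  # SM120 shared memory available per block
assert smem_bytes(128, 64, 128) <= budget   # D=128 fits (98304 bytes)
assert smem_bytes(128, 64, 256) > budget    # D=256 blows the budget
```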
## Hardware support outside SM120

The bundle inherits upstream flash-attn-4's SM80 / SM90 / SM100 dispatch paths. Those should work the same as upstream main; the bundled PRs target SM120 specifically. We do not test SM80 / SM90 / SM100, so please open an issue if you find regressions.
## License
BSD-3-Clause, inherited from Dao-AILab/flash-attention.
## Credits

- Upstream: Dao-AILab/flash-attention (Tri Dao, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, and contributors)
- SM120 PRs and bundle packaging: Blake Ledden, Second Nature Computing
- Hub packaging template: kernels-community/flash-attn4
## Issues
For bundle-specific issues (the dispatch logic, validation gaps, packaging), open an issue on this HF repo. For kernel-level issues that exist upstream, file against Dao-AILab/flash-attention directly.
## See also

- `CONFLICTS_LOG.md`: detailed log of every conflict encountered while stacking the six PRs, with resolution and per-PR backport guidance