liangsu9988 committed
Commit 5a7fea3 · 1 Parent(s): 513d68d

Update Blackwell MSA native API card

Files changed (3)
  1. CARD.md +12 -8
  2. README.md +12 -8
  3. VALIDATION.md +206 -0
CARD.md CHANGED
@@ -39,7 +39,7 @@ msa = get_kernel(
  |---|---|
  | `sparse_decode_atten_func` | Available. Blackwell paged BF16/FP16 single-token decode wrapper. |
  | `SparseDecodePagedAttentionWrapper` | Available. `plan(...).run(...)` wrapper for the same decode path. |
- | `build_k2q_csr` | Available. Torch CSR construction fallback. |
+ | `build_k2q_csr` | Available. CSR construction helper for the official prefill API. |
  | `SparseK2qCsrBuilderSm100` | Available compatibility class; `build()` delegates to `build_k2q_csr`. |
  | `Nvfp4QuantizedTensor` | Available metadata dataclass. |
  | `quantize_bf16_to_nvfp4_128x4` | Available when Transformer Engine NVFP4 support is installed. |
@@ -48,8 +48,8 @@ msa = get_kernel(
  | `swizzle_nvfp4_scale_to_128x4` | Available scale-layout helper. |
  | `nvfp4_global_scale_from_amax` | Available scale helper. |
  | `sparse_atten_func` | Available. Official CSR sparse prefill API backed by the Blackwell Triton BF16/FP16 prefill kernel. |
- | `sparse_atten_nvfp4_kv_func` | Available. NVFP4 KV compatibility path: dequantizes KV with 128x4 metadata, then calls Blackwell sparse prefill. |
- | `fp4_indexer_block_scores` | Available. Correctness-first FP4 block-score fallback returning the official `[Hq, ceil(max_seqlen_k/128), total_q]` score layout. |
+ | `sparse_atten_nvfp4_kv_func` | Available. Built artifacts use native CUDA swizzled NVFP4 -> BF16 dequantization, then call Blackwell sparse prefill. |
+ | `fp4_indexer_block_scores` | Available. Built artifacts use the native CUDA Blackwell block-score kernel and return the official `[Hq, ceil(max_seqlen_k/128), total_q]` score layout. |
 
  ### FlashRT Blackwell helper names
 
@@ -59,6 +59,7 @@ path:
  - `flash_decode_with_topk_idx`
  - `flash_decode_with_gqa_share_sparse`
  - `native_topk_from_scores`
+ - `native_nvfp4_dequant_swizzled_to_bf16`
  - `has_native_ops`
  - `naive_flash_decode_with_topk_idx`
  - `naive_flash_decode_with_gqa_share_sparse`
@@ -206,14 +207,17 @@ out = msa.flash_decode_with_gqa_share_sparse(
  This package contains:
 
  - native CUDA score-to-top-k helper;
- - Blackwell-validated Triton CUDA sparse decode and prefill attention;
+ - native CUDA tensor-core sparse decode route for the MiniMax-M3 Blackwell shape;
+ - native CUDA FP4 block-score indexer;
+ - native CUDA swizzled NVFP4 -> BF16 dequantization for the W4A16 quality path;
+ - Blackwell-validated sparse prefill attention wrapper;
  - MiniMaxAI/msa-compatible Python API layer for decode, prefill, CSR, NVFP4,
  and FP4 block-score helpers.
 
- The optimized SM100 CUTE prefill/indexer bodies are not claimed as ported here.
- For Blackwell, this package provides a validated Triton sparse prefill path and
- correctness-first compatibility fallbacks where the original API requires SM100
- FP4/NVFP4-specific machinery.
+ When loaded from Hub built artifacts, the decode, FP4 indexer, and NVFP4
+ dequant hot paths use compiled CUDA ops. The source-tree mode keeps reference
+ paths so the API and correctness tests remain runnable before a wheel/shared
+ object has been built.
 
  Source provenance and validation details are documented in `SYNC.md` and
  `VALIDATION.md`.
README.md CHANGED
@@ -39,7 +39,7 @@ msa = get_kernel(
  |---|---|
  | `sparse_decode_atten_func` | Available. Blackwell paged BF16/FP16 single-token decode wrapper. |
  | `SparseDecodePagedAttentionWrapper` | Available. `plan(...).run(...)` wrapper for the same decode path. |
- | `build_k2q_csr` | Available. Torch CSR construction fallback. |
+ | `build_k2q_csr` | Available. CSR construction helper for the official prefill API. |
  | `SparseK2qCsrBuilderSm100` | Available compatibility class; `build()` delegates to `build_k2q_csr`. |
  | `Nvfp4QuantizedTensor` | Available metadata dataclass. |
  | `quantize_bf16_to_nvfp4_128x4` | Available when Transformer Engine NVFP4 support is installed. |
@@ -48,8 +48,8 @@ msa = get_kernel(
  | `swizzle_nvfp4_scale_to_128x4` | Available scale-layout helper. |
  | `nvfp4_global_scale_from_amax` | Available scale helper. |
  | `sparse_atten_func` | Available. Official CSR sparse prefill API backed by the Blackwell Triton BF16/FP16 prefill kernel. |
- | `sparse_atten_nvfp4_kv_func` | Available. NVFP4 KV compatibility path: dequantizes KV with 128x4 metadata, then calls Blackwell sparse prefill. |
- | `fp4_indexer_block_scores` | Available. Correctness-first FP4 block-score fallback returning the official `[Hq, ceil(max_seqlen_k/128), total_q]` score layout. |
+ | `sparse_atten_nvfp4_kv_func` | Available. Built artifacts use native CUDA swizzled NVFP4 -> BF16 dequantization, then call Blackwell sparse prefill. |
+ | `fp4_indexer_block_scores` | Available. Built artifacts use the native CUDA Blackwell block-score kernel and return the official `[Hq, ceil(max_seqlen_k/128), total_q]` score layout. |
 
  ### FlashRT Blackwell helper names
 
@@ -59,6 +59,7 @@ path:
  - `flash_decode_with_topk_idx`
  - `flash_decode_with_gqa_share_sparse`
  - `native_topk_from_scores`
+ - `native_nvfp4_dequant_swizzled_to_bf16`
  - `has_native_ops`
  - `naive_flash_decode_with_topk_idx`
  - `naive_flash_decode_with_gqa_share_sparse`
@@ -206,14 +207,17 @@ out = msa.flash_decode_with_gqa_share_sparse(
  This package contains:
 
  - native CUDA score-to-top-k helper;
- - Blackwell-validated Triton CUDA sparse decode and prefill attention;
+ - native CUDA tensor-core sparse decode route for the MiniMax-M3 Blackwell shape;
+ - native CUDA FP4 block-score indexer;
+ - native CUDA swizzled NVFP4 -> BF16 dequantization for the W4A16 quality path;
+ - Blackwell-validated sparse prefill attention wrapper;
  - MiniMaxAI/msa-compatible Python API layer for decode, prefill, CSR, NVFP4,
  and FP4 block-score helpers.
 
- The optimized SM100 CUTE prefill/indexer bodies are not claimed as ported here.
- For Blackwell, this package provides a validated Triton sparse prefill path and
- correctness-first compatibility fallbacks where the original API requires SM100
- FP4/NVFP4-specific machinery.
+ When loaded from Hub built artifacts, the decode, FP4 indexer, and NVFP4
+ dequant hot paths use compiled CUDA ops. The source-tree mode keeps reference
+ paths so the API and correctness tests remain runnable before a wheel/shared
+ object has been built.
 
  Source provenance and validation details are documented in `SYNC.md` and
  `VALIDATION.md`.
VALIDATION.md ADDED
@@ -0,0 +1,206 @@
+ # Validation
+
+ ## Target
+
+ - Kernel family: MiniMax M3 sparse attention (MSA)
+ - Package: `flashrt/MiniMaxAI-msa-blackwell`
+ - HF Jobs package selector: `MiniMaxAI-msa-blackwell`
+ - Package version: v1 Blackwell native-helper package
+ - Target GPU family: Blackwell CUDA compute capability 12.x
+ - Validated GPU: SM121 / GB10 / DGX Spark
+ - Dtype: BF16 inputs with FP32 accumulation references
+ - Layout: paged KV cache
+ - Model path: FlashRT MiniMax-Spark runtime on DGX Spark / GB10
+
+ ## Correctness Gate
+
+ Run quick validation:
+
+ ```bash
+ PYTHONPATH=MiniMaxAI-msa-blackwell/torch-ext \
+ python MiniMaxAI-msa-blackwell/tests/test_msa_blackwell.py --quick
+ ```
+
+ Run full validation:
+
+ ```bash
+ PYTHONPATH=MiniMaxAI-msa-blackwell/torch-ext \
+ python MiniMaxAI-msa-blackwell/tests/test_msa_blackwell.py
+ ```
+
+ Run standalone long-context validation:
+
+ ```bash
+ PYTHONPATH=MiniMaxAI-msa-blackwell/torch-ext \
+ python MiniMaxAI-msa-blackwell/tests/test_msa_blackwell.py --long-context
+ ```
+
+ Expected full coverage:
+
+ | Area | Shapes | Reference | Required |
+ |---|---:|---|---|
+ | API surface | official `MiniMaxAI/msa` public names | `api_status.py` | all official root names exported; no unsupported public root API entries |
+ | Native CUDA top-k helper | heads 64, batch 1-2, blocks 1-256 | PyTorch top-k over valid blocks | exact set match |
+ | Decode sparse GQA attention | ctx 128, 2048, 4096, 32768 | paged FP32 PyTorch | cos >= 0.999, max_abs <= 5e-2 |
+ | Prefill sparse GQA attention | ctx 512, 4096 | paged causal FP32 PyTorch | cos >= 0.999, max_abs <= 5e-2 |
+ | Decode sparse GQA attention with sink | ctx 2048, 32768 | paged FP32 PyTorch | cos >= 0.999, max_abs <= 5e-2 |
+ | Official decode API wrapper | ctx 2048, 4096 | direct Blackwell decode kernel | cos = 1.0, max_abs = 0 |
+ | Official CSR prefill API wrapper | ctx 512, 2048 | direct Blackwell prefill kernel | cos = 1.0, max_abs = 0 under CSR-preserved block order |
+ | Official NVFP4 prefill API wrapper | ctx 512 BF16 dispatch path | `sparse_atten_func` | cos = 1.0, max_abs = 0 |
+ | Native CUDA NVFP4 dequant | rows/cols `(1,128)`, `(257,128)`, `(64,4096)` | Python NVFP4 reference | exact BF16 match |
+ | Official FP4 indexer API | tiny FP4 packed tensors; native artifact path when built | PyTorch block-score reference | returns official score layout |
+ | Decode lightning indexer | ctx 2048, 4096, 32768 | PyTorch blockmax top-k set | overlap >= 0.99 |
+ | Standalone long-context decode | ctx 65536, 131072 | paged FP32 PyTorch / direct kernel | cos >= 0.999; wrapper max_abs = 0 |
+ | Installed-artifact native long top-k | blocks 512, 1024 | PyTorch top-k over valid blocks | exact set match |
+
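Note on the `Required` column: the gates combine three metrics, namely cosine similarity, maximum absolute error, and top-k set overlap. The following is a minimal sketch of how such gates could be computed on flat Python lists. The helper names (`cosine_similarity`, `max_abs_error`, `topk_overlap`, `passes_attention_gate`) are hypothetical and are not taken from `test_msa_blackwell.py`, which presumably operates on tensors; only the thresholds come from the table.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def max_abs_error(a, b):
    """Largest element-wise absolute difference."""
    return max(abs(x - y) for x, y in zip(a, b))

def topk_overlap(selected, reference):
    """Fraction of reference block indices that were also selected (order ignored)."""
    ref = set(reference)
    return len(ref & set(selected)) / len(ref)

def passes_attention_gate(out, ref, cos_min=0.999, max_abs_max=5e-2):
    """Thresholds mirror the 'cos >= 0.999, max_abs <= 5e-2' rows of the table."""
    return (cosine_similarity(out, ref) >= cos_min
            and max_abs_error(out, ref) <= max_abs_max)
```

An "exact set match" row corresponds to `topk_overlap(...) == 1.0` with equal set sizes, while the lightning-indexer row only requires an overlap of at least 0.99.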
+ API surface validation:
+
+ ```bash
+ PYTHONPATH=MiniMaxAI-msa-blackwell/torch-ext \
+ python -m pytest MiniMaxAI-msa-blackwell/tests/test_api_surface.py -q
+ ```
+
+ The test tracks every official `MiniMaxAI/msa` public API name:
+
+ - `sparse_atten_func`
+ - `sparse_atten_nvfp4_kv_func`
+ - `sparse_decode_atten_func`
+ - `SparseDecodePagedAttentionWrapper`
+ - `fp4_indexer_block_scores`
+ - `build_k2q_csr`
+ - `SparseK2qCsrBuilderSm100`
+ - `Nvfp4QuantizedTensor`
+ - `quantize_bf16_to_nvfp4_128x4`
+ - `quantize_kv_bf16_to_nvfp4_128x4`
+ - `dequantize_nvfp4_128x4_to_bf16`
+ - `swizzle_nvfp4_scale_to_128x4`
+ - `nvfp4_global_scale_from_amax`
+
+ The root module exports every official public name. Decode, CSR prefill, NVFP4
+ prefill compatibility, FP4 block scoring, CSR, and NVFP4 helper names are all
+ callable. Hub built artifacts use compiled CUDA ops for the MiniMax-M3
+ Blackwell decode route, FP4 block-score indexer, and swizzled NVFP4 -> BF16
+ dequantization path. Source-tree mode keeps reference paths so the API remains
+ testable before the extension is built.
+
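Note on the built-artifact versus source-tree split: one common way to get this behavior is to look for a compiled op at call time and fall back to a reference implementation when it is missing. The sketch below illustrates that pattern only. The module name `hypothetical_msa_native_ops` and the functions `load_optional`, `reference_topk`, and `topk_from_scores` are assumptions made for illustration and are not this package's actual loader or API.

```python
import importlib

def load_optional(module_name):
    """Return the imported module, or None when it has not been built or installed."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None

def reference_topk(scores, k):
    """Pure-Python reference: indices of the k largest scores, returned sorted."""
    order = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
    return sorted(order[:k])

def topk_from_scores(scores, k, native_module=None):
    """Prefer a compiled op when present; otherwise use the reference path."""
    native = getattr(native_module, "topk_from_scores", None)
    if native is not None:
        return native(scores, k)
    return reference_topk(scores, k)

# Hypothetical module name; it is expected to be absent, so the reference path runs.
native_module = load_optional("hypothetical_msa_native_ops")
selected = topk_from_scores([0.1, 0.9, 0.5, 0.7], 2, native_module)
```

With this shape, the same test can run in both modes and only the backing implementation changes.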
+ ## FlashRT Integration Note
+
+ FlashRT has validated the decode sparse path on SM121 over context lengths
+ 128 to 32768 with cosine similarity >= 0.999. The 32768 context length has
+ also been exercised in the FlashRT MiniMax-Spark model runtime on DGX Spark /
+ GB10, so it is the current end-to-end model validation boundary.
+
+ The standalone package kernel tests additionally cover 65536 and 131072
+ context lengths. These long-context rows validate the kernel and API wrapper
+ contract outside the full model runtime; they should not be described as
+ MiniMax-Spark end-to-end model validation until the full runtime path is rerun
+ at those lengths.
+
+ The end-to-end MiniMax-Spark runtime validation mentioned above is
+ intentionally kept as a FlashRT runtime validation item.
+ This Hub package, by contrast, exposes the standalone kernel API for
+ community use.
+
+ ## Native Helper Compile Smoke
+
+ Before the HF Jobs publish step, the native helper was compiled locally as a
+ PyTorch extension from the same source files:
+
+ - `torch-ext/torch_binding.cpp`
+ - `csrc/msa_topk_from_scores.cu`
+ - `csrc/msa_decode_attn.cu`
+ - `csrc/msa_decode_attn_mma.cu`
+ - `csrc/msa_indexer_block_scores.cu`
+ - `csrc/msa_nvfp4_dequant.cu`
+
+ Environment:
+
+ | Field | Value |
+ |---|---|
+ | GPU | NVIDIA GeForce RTX 5090 |
+ | PyTorch | 2.9.1+cu128 |
+ | nvcc | CUDA 13.0 |
+ | Target arch | sm_120 |
+
+ Result:
+
+ | Check | Shape | Reference | Verdict |
+ |---|---:|---|---|
+ | Native score -> top-k | heads 64, batch 1, blocks 256, topk 16 | PyTorch top-k set | PASS |
+ | Native FP4 block-score indexer | official `[Hq, blocks, total_q]` score layout | PyTorch block-score reference | PASS |
+ | Native NVFP4 swizzled -> BF16 dequant | rows/cols `(1,128)`, `(257,128)`, `(64,4096)` | Python NVFP4 reference | PASS |
+
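Note on the NVFP4 dequant check: NVFP4 stores 4-bit E2M1 elements that share one scale per 16-element block, plus a per-tensor scale. The sketch below shows what a Python reference for one row might look like under simplifying assumptions. It takes block scales as already-decoded floats, assumes the low nibble of each byte holds the earlier element, multiplies by the global scale (the package may use the reciprocal convention), and ignores both the 128x4 swizzled scale layout and BF16 rounding. `dequantize_row` is a hypothetical name, not the package's reference function.

```python
# Magnitudes representable by the 3 low bits of an E2M1 code; bit 3 is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK = 16  # elements sharing one block scale

def decode_e2m1(code):
    """Decode one 4-bit E2M1 code to a float."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude

def dequantize_row(packed, block_scales, global_scale=1.0):
    """Unpack two FP4 values per byte and apply block and global scales."""
    values = []
    for byte in packed:
        # Assumption: low nibble first, then high nibble.
        for nibble in (byte & 0xF, byte >> 4):
            scale = block_scales[len(values) // BLOCK] * global_scale
            values.append(decode_e2m1(nibble) * scale)
    return values

# One 16-element block packed into 8 bytes, with a block scale of 2.0.
row = dequantize_row(bytes([0x21, 0xF9, 0, 0, 0, 0, 0, 0]), [2.0])
```

An exact-match gate against such a reference is feasible because dequantization is a table lookup followed by multiplications, with no accumulation order to vary.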
+ ## Blackwell Package Validation
+
+ Remote Blackwell validation environment:
+
+ | Field | Value |
+ |---|---|
+ | Host | `spark-f517` |
+ | GPU | NVIDIA GB10 |
+ | Compute capability | 12.1 |
+ | Driver | 580.159.03 |
+ | Python | 3.12.3 |
+ | PyTorch | 2.12.0+cu130 |
+ | Triton | 3.7.0 |
+
+ Command:
+
+ ```bash
+ PY=/home/leadtek/jax/bin/python
+ PYTHONPATH=MiniMaxAI-msa-blackwell/torch-ext \
+ $PY MiniMaxAI-msa-blackwell/tests/test_msa_blackwell.py
+ ```
+
+ Result:
+
+ | Check | Shape | Cosine | Max abs / overlap | Verdict |
+ |---|---|---:|---:|---|
+ | Decode sparse GQA | ctx128_b1 | 0.999998 | 1.6032e-03 | PASS |
+ | Decode sparse GQA | ctx2048_b1 | 0.999996 | 4.9090e-04 | PASS |
+ | Decode sparse GQA | ctx2048_b2_sink | 0.999996 | 6.8302e-04 | PASS |
+ | Decode sparse GQA | ctx4096_b1 | 0.999996 | 4.5899e-04 | PASS |
+ | Decode sparse GQA | ctx4096_b2_mixed | 0.999996 | 7.3129e-04 | PASS |
+ | Decode sparse GQA | ctx32768_b1 | 0.999996 | 6.9451e-04 | PASS |
+ | Decode sparse GQA | ctx32768_b1_sink | 0.999996 | 5.6115e-04 | PASS |
+ | Decode sparse GQA | ctx65536_b1 | 0.999996 | 4.3470e-04 | PASS |
+ | Decode sparse GQA | ctx131072_b1 | 0.999996 | 7.1825e-04 | PASS |
+ | Decode top-k indexer | ctx2048 | n/a | overlap 1.000 | PASS |
+ | Decode top-k indexer | ctx4096 | n/a | overlap 1.000 | PASS |
+ | Decode top-k indexer | ctx32768 | n/a | overlap 1.000 | PASS |
+ | Decode top-k indexer | ctx65536 | n/a | overlap 1.000 | PASS |
+ | Decode top-k indexer | ctx131072 | n/a | overlap 1.000 | PASS |
+ | Official decode wrapper | ctx2048 | 1.000000 | 0.0000e+00 | PASS |
+ | Official decode wrapper | ctx4096 | 1.000000 | 0.0000e+00 | PASS |
+ | Official decode wrapper | ctx65536 | 1.000000 | 0.0000e+00 | PASS |
+ | Official decode wrapper | ctx131072 | 1.000000 | 0.0000e+00 | PASS |
+ | Native CUDA NVFP4 dequant | rows1_cols128 | 1.000000 | 0.0000e+00 | PASS |
+ | Native CUDA NVFP4 dequant | rows257_cols128 | 1.000000 | 0.0000e+00 | PASS |
+ | Native CUDA NVFP4 dequant | rows64_cols4096 | 1.000000 | 0.0000e+00 | PASS |
+
+ Installed-artifact native top-k validation on RTX 5090 / torch 2.11 / CUDA
+ 12.8:
+
+ | Context | Blocks | Overlap | Verdict |
+ |---:|---:|---:|---|
+ | 32768 | 256 | 1.000 | PASS |
+ | 65536 | 512 | 1.000 | PASS |
+ | 131072 | 1024 | 1.000 | PASS |
+
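Note on the `Blocks` column: the counts follow from the 128-token block size implied by the official score layout `[Hq, ceil(max_seqlen_k/128), total_q]`. A small sketch of that arithmetic follows; `num_blocks` and `score_shape` are hypothetical helper names, and the block size is passed explicitly so the assumption stays visible.

```python
BLOCK_SIZE = 128  # from the ceil(max_seqlen_k/128) term in the score layout

def num_blocks(seqlen_k, block_size=BLOCK_SIZE):
    """Ceiling division: number of KV blocks covering seqlen_k tokens."""
    return (seqlen_k + block_size - 1) // block_size

def score_shape(num_q_heads, max_seqlen_k, total_q):
    """Shape of a block-score tensor in the [Hq, blocks, total_q] layout."""
    return (num_q_heads, num_blocks(max_seqlen_k), total_q)

# Reproduces the Context -> Blocks pairs in the table above.
table = {ctx: num_blocks(ctx) for ctx in (32768, 65536, 131072)}
```

A context length that is not a multiple of 128 still gets a final partial block, for example 129 tokens occupy 2 blocks.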
+ The warning `tl.make_block_ptr is deprecated` appears with Triton 3.7.0. It is
+ a deprecation warning, not a correctness failure.
+
+ ## Native Alignment Status
+
+ The upstream `MiniMaxAI/msa` package targets SM100. This Blackwell package
+ keeps the same public API surface where practical and provides native CUDA
+ implementations for the hot paths needed by the FlashRT MiniMax-Spark runtime:
+
+ - score-to-top-k sparse block selection;
+ - tensor-core sparse decode for the MiniMax-M3 Blackwell shape;
+ - FP4 block-score indexing;
+ - swizzled NVFP4 -> BF16 dequantization for the W4A16 path.
+
+ The CSR prefill wrapper remains part of the public compatibility surface and is
+ validated against the package reference path. Shape and parameter restrictions
+ are explicit errors rather than silent wrong results.
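Note on CSR construction and explicit errors: the name `build_k2q_csr` suggests a key-block-to-query mapping stored in compressed sparse row form. The sketch below assumes that reading. It converts per-query selected block lists into a `(row_ptr, col_idx)` pair indexed by key block and raises on out-of-range block indices instead of writing a corrupt structure. `build_k2q_csr_sketch` and its argument layout are hypothetical and may differ from the real helper's signature and tensor formats.

```python
def build_k2q_csr_sketch(q2k_blocks, num_k_blocks):
    """Invert query -> selected key blocks into key block -> queries (CSR)."""
    counts = [0] * num_k_blocks
    for q, blocks in enumerate(q2k_blocks):
        for b in blocks:
            if not 0 <= b < num_k_blocks:
                # Explicit error rather than a silently wrong result.
                raise ValueError(f"query {q}: block {b} outside [0, {num_k_blocks})")
            counts[b] += 1

    # Exclusive prefix sum gives the start offset of each key block's row.
    row_ptr = [0]
    for c in counts:
        row_ptr.append(row_ptr[-1] + c)

    cursor = row_ptr[:-1]           # next free slot per key block (slice is a copy)
    col_idx = [0] * row_ptr[-1]
    for q, blocks in enumerate(q2k_blocks):
        for b in blocks:
            col_idx[cursor[b]] = q  # queries stay in ascending order within a row
            cursor[b] += 1
    return row_ptr, col_idx

row_ptr, col_idx = build_k2q_csr_sketch([[0, 2], [2], [1, 2]], num_k_blocks=3)
```

Here key block 2 is selected by all three queries, so its row spans `col_idx[row_ptr[2]:row_ptr[3]]`.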