datasysdev committed on
Commit
6df7db9
·
verified ·
1 Parent(s): 73b7176

Update model card with all-layer results

Files changed (1)
  1. README.md +124 -489
README.md CHANGED
@@ -1,188 +1,69 @@
1
  # ann-sparseattention
2
 
3
- Train tiny per-layer "search projections" on a frozen LLM that replicate the
4
- attention's top-K preferences in a low-dimensional space, so we can swap dense
5
- quadratic attention for an off-the-shelf ANN index (FAISS HNSW) at inference
6
- and lose almost no model quality.
7
 
8
  ## Current status
9
 
10
- Research prototype. The trained projections work in a narrow 6-layer packed
11
- WikiText-103 pilot on `Qwen/Qwen3-4B-Instruct-2507`, but the runtime is still
12
- a correctness prototype. Treat reported numbers as preliminary until confidence
13
- intervals, downstream long-context tasks, and real baselines are run.
14
-
15
- Checkpoint artifacts and JSON eval outputs are mirrored on Hugging Face:
16
- [`datasysdev/ann-sparseattention`](https://huggingface.co/datasysdev/ann-sparseattention).
17
- Use `checkpoints_block_d128/search_step_1000.pt` there for the current clean
18
- block-causal result.
19
-
20
- **What's validated:**
21
- - 6-layer packed pilot on Qwen3-4B-Instruct-2507, layers
22
- `[4, 8, 12, 16, 20, 24]`, 4K context, 1K training steps.
23
- - `d_search=128` is the current recommended capacity from the packed capacity
24
- ablation: 3.93M trainable
25
- parameters, mass@K=128 of 0.503 vs 0.488 for the raw-QK exact-topK oracle,
26
- and -1.81% relative PPL gap at K=128 on the packed eval slice.
27
- - Block-causal packed masking is implemented. On the clean block-causal d128
28
- rerun, exact sparse attention is near parity with full attention
29
- (K=128: +0.07% PPL gap; K=256: +0.01%). The large negative PPL gaps from
30
- packed-with-leakage do not survive as a clean-methodology headline.
31
- - Capacity scaling is monotonic but saturating: d64 < d128 < d256 on mass@K,
32
- while d128 and d256 are effectively tied on final PPL.
33
- - Learned projections outperform raw-QK oracle mass in mid/late trained layers
34
- (L12-L24), while early layers remain harder.
35
-
36
- **Not yet validated (next iteration):**
37
- - Confidence intervals for the block-causal result over multiple seeds and
38
- larger eval slices.
39
- - Quest / RetrievalAttention baselines.
40
- - Long-context task quality (LongBench, RULER, needle-in-haystack).
41
- - 34-layer / whole-model substitution.
42
- - Wall-clock speedup vs. FlashAttention/SDPA — not measured.
43
- - KV-cache decode-mode integration.
44
- - GPU-resident ANN or fused gather-attention kernel.
45
-
46
- **Runtime caveat.** The current FAISS path is a correctness prototype: it
47
- builds a CPU index per forward pass and uses dense-style tensor expansion
48
- internally for the gather step. The compute-reduction numbers below are
49
- **algorithmic scoring reductions, not measured wall-clock speedups.** A
50
- production runtime requires a GPU-resident topk kernel or integration with
51
- paged/block-sparse attention kernels.
52
-
53
- ### d_search ablation (packed WikiText-103, K=128)
54
-
55
- The packed ablation trains the same 6 layers for 1K steps and evaluates all
56
- variants with the same packed eval pipeline. `raw_qk` is exact top-K over
57
- head-mean-aggregated native post-RoPE Q/K vectors; `learned` is exact top-K
58
- over trained search projections. mass@K is teacher-attention probability
59
- captured by the retrieved set.
60
-
61
- | d_search | Params | learned mass@K=128 | raw-QK oracle | learned / oracle | Final PPL gap |
62
- |---|---:|---:|---:|---:|---:|
63
- | 64 | 1.97M | 0.492 | 0.488 | 1.01x | +2.39% |
64
- | **128** | **3.93M** | **0.503** | **0.488** | **1.03x** | **-1.81%** |
65
- | 256 | 7.86M | 0.509 | 0.488 | 1.04x | -1.85% |
66
-
67
- d128 is the recommended default for this pilot: it captures almost all of the
68
- d256 quality with half the trainable parameters. d256 improves mass@K slightly
69
- but does not materially improve final PPL.
70
-
71
- PPL gap is the primary model-quality signal; mass@K is the more direct
72
- retrieval-quality signal when teacher attention is sharp. Recall@K is logged,
73
- but it is a weaker proxy because disagreement on near-zero-probability tail
74
- positions can look like low recall while preserving model output.
75
-
76
- Per-layer mass@K=128 for d128:
77
-
78
- | Layer | raw-QK oracle | learned d128 |
79
- |---|---:|---:|
80
- | 4 | 0.422 | 0.382 |
81
- | 8 | 0.518 | 0.421 |
82
- | 12 | 0.404 | 0.533 |
83
- | 16 | 0.475 | 0.481 |
84
- | 20 | 0.499 | 0.551 |
85
- | 24 | 0.614 | 0.648 |
86
-
87
- Early layers remain harder for learned retrieval; mid/late trained layers
88
- exceed raw-QK oracle mass.
89
-
90
- ### K-retrieve Pareto (packed d128, leakage-confounded)
91
-
92
- Exact top-K sweep for the recommended packed d128 checkpoint:
93
-
94
- ```bash
95
- python k_sweep.py \
96
- --ckpt /tmp/checkpoints_packed_d128/search_step_1000.pt \
97
- --K 128,256,512 \
98
- --no-use-faiss
99
- ```
100
-
101
- `PPL_full = 224.64` on this packed eval slice.
102
-
103
- | K | Recall@K | mass@K | PPL_ANN | PPL gap |
104
- |---|---:|---:|---:|---:|
105
- | 128 | 0.166 | 0.256 | 203.63 | -9.36% |
106
- | 256 | 0.233 | 0.318 | 207.06 | -7.83% |
107
- | 512 | 0.339 | 0.409 | 211.93 | -5.66% |
108
-
109
- This disambiguates the earlier FAISS high-K failure on the leaked packed
110
- pipeline: exact retrieval remains
111
- strongly negative at K=256/512, so the denoising pattern is present on this
112
- packed eval slice. This should not be used as a publication-strength denoising
113
- claim because packed examples can attend across document boundaries.
114
-
115
- A second exact sweep on the next 16 packed eval batches (`--skip-batches 16`)
116
- preserved the shape: K=128 -8.78%, K=256 -7.59%, K=512 -6.21%. This is still
117
- not a substitute for confidence intervals, but it reduces the chance that the
118
- large negative gap is a single-slice accident.
119
-
120
- ### Block-causal packed d128 (clean masking)
121
-
122
- Packed block-causal masking assigns each packed document a `segment_id`, resets
123
- `position_ids` at segment boundaries, and supplies a 4D additive mask so tokens
124
- can only attend causally within their own document. Retrieval, loss masking,
125
- mass@K, and recall@K use the same segment-causal eligibility mask.
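-
- As an illustration of this masking scheme, the sketch below builds restart
- `position_ids` and a 4D additive mask from packed `segment_ids`. It is a
- minimal stand-alone version, not the repo's exact helper.
-
- ```python
- import torch
-
- def block_causal_inputs(segment_ids: torch.Tensor, dtype=torch.float32):
-     """segment_ids: [batch, seq] int tensor with one id per packed document."""
-     _, seq = segment_ids.shape
-     pos = torch.arange(seq, device=segment_ids.device)
-     # Restart position_ids at every segment boundary.
-     boundary = torch.ones_like(segment_ids, dtype=torch.bool)
-     boundary[:, 1:] = segment_ids[:, 1:] != segment_ids[:, :-1]
-     seg_start = torch.cummax(torch.where(boundary, pos, 0), dim=1).values
-     position_ids = pos - seg_start
-     # Tokens may only attend to causal positions inside their own segment.
-     causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool, device=pos.device))
-     same_seg = segment_ids[:, :, None] == segment_ids[:, None, :]
-     allowed = causal[None] & same_seg                        # [batch, seq, seq]
-     # 4D additive mask: 0 where allowed, a large negative value elsewhere.
-     attn_mask = torch.zeros(allowed.shape, dtype=dtype, device=pos.device)
-     attn_mask = attn_mask.masked_fill(~allowed, torch.finfo(dtype).min)[:, None]
-     return position_ids, attn_mask
- ```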
126
-
127
- Clean d128 block-causal run:
128
-
129
- ```bash
130
- python train.py --config pilot_d128_block
131
- python k_sweep.py \
132
- --ckpt /tmp/checkpoints_block_d128/search_step_1000.pt \
133
- --K 128,256,512 \
134
- --no-use-faiss
135
- ```
136
-
137
- `PPL_full = 30.44` on the 16-batch clean eval slice.
138
-
139
- | K | Recall@K | mass@K | PPL_ANN | PPL gap |
140
- |---|---:|---:|---:|---:|
141
  | 128 | 0.744 | 0.787 | 30.47 | +0.07% |
142
  | 256 | 0.879 | 0.953 | 30.45 | +0.01% |
143
  | 512 | n/a | n/a | 30.45 | +0.01% |
144
 
145
- K=512 has no meaningful mass/recall average on this WikiText slice because
146
- almost no same-segment queries have 512 valid causal keys. The quality result
147
- is still useful: with filler slots masked out of the sparse-attention softmax,
148
- the block-causal exact path is effectively at full-attention parity. The clean
149
- result supports "quality-preserving sparse substitution" rather than the leaked
150
- pipeline's stronger denoising claim.
151
-
152
- Clean block-causal per-layer `compare_retrieval` at K=128:
153
-
154
- | Layer | raw-QK oracle mass | learned d128 mass |
155
- |---|---:|---:|
156
- | 4 | 0.956 | 0.950 |
157
- | 8 | 0.977 | 0.976 |
158
- | 12 | 0.970 | 0.977 |
159
- | 16 | 0.964 | 0.970 |
160
- | 20 | 0.970 | 0.983 |
161
- | 24 | 0.978 | 0.984 |
162
- | avg | 0.969 | 0.973 |
163
-
164
- This changes the per-layer interpretation from the leakage-confounded pilot:
165
- with segment isolation, early trained layers are not diffuse or uniquely hard.
166
- All six trained layers have high oracle mass, and learned projections match or
167
- slightly exceed raw-QK retrieval across the set. The deployment hypothesis for
168
- the next run is therefore "substitute all tested layers" rather than "keep early
169
- layers as full attention," pending a broader all-layer run.
170
-
171
- ### Quest-style page baseline (clean block-causal)
172
-
173
- `quest_sweep.py` implements a Quest-style min/max page selector for comparison:
174
- page size 16, native post-RoPE Q/K, same block-causal token eligibility mask,
175
- and the same sparse-attention gather path. This is a correctness baseline, not
176
- an optimized Quest runtime.
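-
- For reference, a minimal sketch of the min/max page-scoring rule this baseline
- uses (the real `quest_sweep.py` additionally handles heads, batching, and the
- block-causal eligibility mask):
-
- ```python
- import torch
-
- def quest_page_upper_bounds(q, k, page_size=16):
-     """q: [n_q, d], k: [n_k, d]. Returns an upper bound on q·k per page, [n_q, n_pages]."""
-     n_k, d = k.shape
-     n_pages = -(-n_k // page_size)
-     pad = n_pages * page_size - n_k
-     if pad:
-         k = torch.cat([k, k.new_full((pad, d), float("nan"))])
-     pages = k.view(n_pages, page_size, d)
-     # Per-page, per-dimension min and max over the keys (NaN padding excluded).
-     k_max = torch.where(pages.isnan(), torch.full_like(pages, float("-inf")), pages).amax(dim=1)
-     k_min = torch.where(pages.isnan(), torch.full_like(pages, float("inf")), pages).amin(dim=1)
-     # Per dimension, the larger of q_d*min_d and q_d*max_d bounds the true score.
-     ub = torch.maximum(q[:, None] * k_min[None], q[:, None] * k_max[None]).sum(-1)
-     return ub  # select the highest-scoring pages until roughly K tokens are covered
- ```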
177
-
178
- ```bash
179
- python quest_sweep.py \
180
- --ckpt /tmp/checkpoints_block_d128/search_step_1000.pt \
181
- --K 128,256,512 \
182
- --page-size 16
183
- ```
184
-
185
- On the same 16-batch block-causal eval slice:
186
 
187
  | Method | K | Recall@K | mass@K | PPL | PPL gap |
188
  |---|---:|---:|---:|---:|---:|
@@ -191,330 +72,84 @@ On the same 16-batch block-causal eval slice:
191
  | learned search exact | 256 | 0.879 | 0.953 | 30.45 | +0.01% |
192
  | Quest-style page | 256 | 0.838 | 0.909 | 30.45 | +0.03% |
193
 
194
- Both methods are effectively at full-attention parity on PPL. The learned search
195
- space recovers more teacher attention mass at the same token budget, especially
196
- at K=128, while Quest remains a strong non-trained heuristic baseline. This
197
- keeps the contribution narrow: learned projections improve retrieval fidelity
198
- and support standard ANN indexing; they do not yet show a clean PPL advantage
199
- over Quest on this slice.
200
-
201
- Paired 32-batch NLL evaluation gives a sharper comparison:
202
 
203
- | K | full PPL | learned PPL | Quest PPL | learned - Quest NLL delta (95% bootstrap CI) | Read |
204
- |---|---:|---:|---:|---:|---|
205
  | 128 | 28.03 | 28.07 | 28.01 | +0.00205 `[+0.00160, +0.00251]` | Quest slightly better |
206
  | 256 | 28.03 | 28.04 | 28.04 | -0.00005 `[-0.00029, +0.00018]` | statistical tie |
207
 
208
- So the current clean result is: learned search has higher teacher-attention
209
- mass, but PPL is either tied with Quest (K=256) or slightly worse (K=128) on
210
- this paired WikiText slice. The paper claim should be "retrieval-fidelity and
211
- ANN-compatibility advantages," not "PPL advantage over Quest."
212
 
213
- ### Clean FAISS-vs-exact check
214
 
215
- The first block-causal FAISS prototype used one global index followed by
216
- segment filtering, which produced pathological filler rates after filtering.
217
- The current FAISS path builds per-segment indexes when a 4D block-causal mask
218
- is present. With that fix, CPU FAISS/HNSW tracks exact learned search on the
219
- same 16-batch clean eval slice:
220
 
221
- | Method | K | PPL | PPL gap | FAISS filler rate |
222
- |---|---:|---:|---:|---:|
223
- | learned exact | 128 | 30.47 | +0.07% | n/a |
224
- | learned FAISS/HNSW | 128 | 30.47 | +0.09% | 0.447 |
225
- | learned exact | 256 | 30.45 | +0.01% | n/a |
226
- | learned FAISS/HNSW | 256 | 30.46 | +0.04% | 0.683 |
227
 
228
- The remaining filler rate is expected for short same-segment prefixes where
229
- fewer than K valid causal keys exist; filler slots are masked out of the sparse
230
- attention softmax. This demonstrates off-the-shelf ANN compatibility in the
231
- clean block-causal setting, but not production wall-clock speedup.
232
 
233
- ### Asymptotic scoring analysis
234
 
235
- `artifacts/scaling_analysis.md` gives a deterministic operation-count proxy
236
- for the per-query candidate scoring step. This is the cost of identifying
237
- which keys to attend to, before the sparse attention softmax and value
238
- multiply over the selected keys.
 
239
 
240
- Assumptions:
 
241
 
242
- - Full attention scoring: `N * d_head = N * 128`.
243
- - Quest-style page scoring: `(N / page_size) * 2 * d_head = N * 16`
244
- with `page_size=16`.
245
- - Learned HNSW scoring: `M * ef_search * log2(N) * d_search`
246
- with `M=32`, `ef_search=64`, and `d_search=128`.
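-
- A small script reproduces this proxy (assuming decimal context sizes, e.g.
- 4K = 4,000 tokens, which is consistent with the table below):
-
- ```python
- import math
-
- D_HEAD, PAGE_SIZE, D_SEARCH, M, EF_SEARCH = 128, 16, 128, 32, 64
-
- def ops_per_query(n):
-     full = n * D_HEAD                                # score every key
-     quest = (n / PAGE_SIZE) * 2 * D_HEAD             # min/max score per page
-     hnsw = M * EF_SEARCH * math.log2(n) * D_SEARCH   # graph walk in the search space
-     return full, quest, hnsw
-
- for ctx in (4_000, 32_000, 256_000, 1_000_000):
-     full, quest, hnsw = ops_per_query(ctx)
-     print(f"{ctx:>9} tokens: full={full:>12,.0f}  quest={quest:>12,.0f}  "
-           f"hnsw={hnsw:>12,.0f}  quest/hnsw={quest / hnsw:.2f}x")
- ```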
247
 
248
- ![Candidate-scoring operations per query](artifacts/scaling_plot.svg)
249
 
250
- | Context | Full ops/query | Quest ops/query | Learned HNSW ops/query | Quest / learned |
251
- |---:|---:|---:|---:|---:|
252
- | 4K | 512,000 | 64,000 | 3,136,759 | 0.02x |
253
- | 8K | 1,024,000 | 128,000 | 3,398,903 | 0.04x |
254
- | 16K | 2,048,000 | 256,000 | 3,661,047 | 0.07x |
255
- | 32K | 4,096,000 | 512,000 | 3,923,191 | 0.13x |
256
- | 64K | 8,192,000 | 1,024,000 | 4,185,335 | 0.24x |
257
- | 128K | 16,384,000 | 2,048,000 | 4,447,479 | 0.46x |
258
- | 256K | 32,768,000 | 4,096,000 | 4,709,623 | 0.87x |
259
- | 512K | 65,536,000 | 8,192,000 | 4,971,767 | 1.65x |
260
- | 1M | 128,000,000 | 16,000,000 | 5,224,942 | 3.06x |
261
- | 2M | 256,000,000 | 32,000,000 | 5,487,086 | 5.83x |
262
- | 4M | 512,000,000 | 64,000,000 | 5,749,230 | 11.13x |
263
-
264
- Under these conservative HNSW constants, Quest is cheaper below the
265
- few-hundred-thousand-token regime and learned-projection scoring becomes
266
- cheaper beyond roughly 300K tokens. At 1M context, the operation-count proxy is
267
- about 3x in favor of learned projections. This supports the theoretical
268
- scaling claim only; production speed claims still require GPU-resident
269
- retrieval and KV-cache/decode integration.
270
-
271
- ### Dynamic-index proxy
272
-
273
- The current ANN wrapper is prefill-only (`use_cache=False`), so a true
274
- generation-time dynamic-index benchmark still requires cache integration. As a
275
- first capability proxy, `dynamic_index_proxy.py` splits clean block-causal eval
276
- sequences into a prefill prefix and decode-like suffix, then compares learned
277
- retrieval mass under two masks:
278
-
279
- - dynamic index: suffix queries can retrieve from all same-segment prior keys;
280
- - static index: suffix queries can retrieve from prefill keys plus a 256-token
281
- recent local suffix window, but not older suffix keys.
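-
- A minimal sketch of the static-index eligibility (hypothetical helper;
- `dynamic_index_proxy.py` also applies the segment mask and runs per-layer
- retrieval on top of it):
-
- ```python
- import torch
-
- def static_proxy_mask(seq_len, prefill_len=1024, local_window=256):
-     """Boolean [seq, seq] mask: True where a query may retrieve a key.
-     The dynamic proxy instead allows all same-segment causal keys."""
-     q_pos = torch.arange(seq_len)[:, None]
-     k_pos = torch.arange(seq_len)[None, :]
-     causal = k_pos <= q_pos
-     in_prefill = k_pos < prefill_len
-     in_local_window = (q_pos - k_pos) < local_window
-     return causal & (in_prefill | in_local_window)
- ```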
282
-
283
- On the clean d128 block-causal checkpoint, using K=128, prefill length 1024,
284
- local window 256, and 8 eval batches:
285
-
286
- | Setting | Teacher mass captured |
287
- |---|---:|
288
- | Dynamic proxy | 0.972 |
289
- | Static proxy | 0.928 |
290
- | Static teacher mass available | 0.954 |
291
- | Dynamic - static | +0.044 |
292
-
293
- Per-layer dynamic-minus-static mass ranges from +0.022 (L04) to +0.058 (L08).
294
- This does not establish task accuracy or decode latency, but it shows that a
295
- frozen prefill-plus-local index loses measurable teacher-attention mass on
296
- decode-like suffix queries. The raw result is in
297
- `artifacts/dynamic_proxy_8b.json`.
298
-
299
- ### Compute / quality knobs (FLOP-counted)
300
-
301
- `L = 4096`. The compute reduction refers to the attention scoring step and is `≈ L / K`.
302
- These are FLOP estimates, not measured wall-clock — the FAISS path in this
303
- repo is a research prototype that does CPU index builds and GPU↔CPU
304
- transfers, so it is not the right thing to time. A GPU-resident topk
305
- kernel is the natural next step.
306
-
307
- | K | PPL gap | Attention scoring reduction |
308
- |---|---|---|
309
- | 512 | -5.66% (exact top-K over learned search space) | ~8x |
310
- | 256 | -7.83% (exact top-K over learned search space) | ~16x |
311
- | 128 | -9.36% exact; -1.81% FAISS/training eval | ~32x |
312
- | 64 | +0.46% | ~64x |
313
- | 32 | +0.03% | ~128x |
314
- | 16 | +5.63% | ~256x |
315
-
316
- Eval scope for the d_search table: 16 packed validation batches at 4K context
317
- for PPL/recall during training, and 12 packed batches for `compare_retrieval`
318
- mass@K. Numbers should be read as "what we observed on this slice", not
319
- population-level estimates.
320
-
321
- ### Caveats / what's next
322
-
323
- A few things the pilot does not yet establish, and that the next iteration
324
- will:
325
-
326
- - **Packing**: the d_search ablation table is still from the packed
327
- leakage-confounded run and is best read as a capacity comparison. The clean
328
- block-causal d128 rerun removes cross-document leakage and should be used for
329
- quality claims.
330
- - **Exact-topK oracle**: the obvious follow-up is a four-way Pareto —
331
- full attention vs. exact top-K (true `QK^T` argmax-K, then attention) vs.
332
- search-topK (our projections, exact distance) vs. search-ANN (FAISS HNSW).
333
- That separates "denoising from any sparsity" from "denoising from learned
334
- projections."
335
- - **Wall-clock**: the compute-reduction table above is FLOP-counted. The
336
- FAISS path here is a research prototype (CPU index per forward, GPU↔CPU
337
- transfer) and is the wrong thing to time. A GPU-resident topk kernel is
338
- the next-step engineering.
339
- - **34-layer headline**: queued, but the VM was reclaimed before launch.
340
- Config is wired (`make_headline_config()`); rerun is a single command on
341
- any B200/H100/H200.
342
-
343
- The recall@K and mass@K reported here come from a 12-batch eval slice, not
344
- a population-level estimate. Confidence intervals and downstream tasks
345
- (LongBench / RULER / needle-in-haystack) are the natural next evals.
346
-
347
- ### Headline run (queued)
348
-
349
- 34 layers (every layer except 0 and 35), 8K context, 6K steps,
350
- ~4-5h on a single B200. Tests whether the technique generalizes from a
351
- 6-layer subset to broad layer coverage. Checkpoints will be mirrored at
352
- [`datasysdev/ann-sparseattention`](https://huggingface.co/datasysdev/ann-sparseattention).
353
-
354
- ## Relation to RetrievalAttention
355
-
356
- The closest prior work is RetrievalAttention (Liu et al., 2024). They show
357
- that **vanilla ANN over the model's native Q and K vectors fails** because
358
- Q and K live in mismatched distributions — they were never trained to be
359
- each other's nearest neighbors, only to score correctly via the dot
360
- product. Their fix is at *index time*: an attention-aware graph
361
- construction (RoarGraph-style) that compensates for the Q/K
362
- out-of-distribution problem at search time.
363
-
364
- This work attacks the same problem from the opposite direction. Instead of
365
- patching the index over hostile vectors, we **train a tiny shared
366
- low-dimensional projection** (`W_Qs, W_Ks → R^128` in the recommended pilot)
367
- so that `q_search` and `k_search` *do* live in the same distribution by construction.
368
- Off-the-shelf FAISS HNSW with default parameters is then sufficient — there is no
369
- attention-aware index trick.
370
-
371
- | | Search space | Index | Trainable |
372
- |---|---|---|---|
373
- | Raw Q/K + vanilla ANN | original Q/K | off-the-shelf | no — fails (Q/K OOD) |
374
- | RetrievalAttention | original Q/K | attention-aware graph | no |
375
- | **This work** | **learned Q\_s / K\_s** | **off-the-shelf** | **yes (~2-11M params)** |
376
-
377
- The contribution claim: *eliminate the Q/K mismatch at index-build time
378
- via distillation, instead of patching it at search time.* The clean
379
- experiment to validate this — vanilla FAISS over raw Q/K vs. vanilla
380
- FAISS over learned Q\_s/K\_s vs. exact teacher top-K, all at the same K —
381
- is the next planned run. The current pilot establishes that the learned
382
- projections retrieve attention-relevant keys; the comparison run isolates
383
- how much of that came from the projection vs. the ANN approximation.
384
-
385
- ## How it works
386
-
387
- For each full-attention layer `i` we train two linear projections
388
- `W_Qs^i, W_Ks^i ∈ R^{d_model × d_search}` (recommended pilot: d_search=128),
389
- so that for any
390
- hidden state `h`,
391
-
392
- ```
393
- q_search = W_Qs^i h
- k_search = W_Ks^i h
394
- softmax(q_search · k_search^T) ranks the same keys as
395
- softmax(QK^T / √d_head) (the teacher's attention)
396
- ```
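-
- A module of roughly this shape implements the projections (illustrative; the
- actual `SearchProjection` in `model.py` may differ in bias and init details).
- The reported 3.93M parameters over six layers correspond to two bias-free
- `d_model × 128` maps per layer at Qwen3-4B's hidden size.
-
- ```python
- import torch.nn as nn
-
- class SearchProjection(nn.Module):
-     """Per-layer search projections W_Qs, W_Ks (sketch)."""
-
-     def __init__(self, d_model: int, d_search: int = 128):
-         super().__init__()
-         self.w_qs = nn.Linear(d_model, d_search, bias=False)
-         self.w_ks = nn.Linear(d_model, d_search, bias=False)
-
-     def forward(self, hidden_states):
-         # The same hidden state feeds both projections, as in the equations above.
-         return self.w_qs(hidden_states), self.w_ks(hidden_states)
- ```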
397
-
398
- Two losses, summed across layers:
399
-
400
- - **InfoNCE** with teacher-derived positives (top-`K_pos` keys from the
401
- teacher's attention serve as positives for each query).
402
- - **KL(teacher ‖ student)** on the full attention distribution.
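-
- For concreteness, a sketch of the two losses for a single layer and a
- head-aggregated teacher (the repo's `model.py` handles temperatures, masking
- details, and batching; the names here are illustrative):
-
- ```python
- import torch
- import torch.nn.functional as F
-
- def distill_losses(q_s, k_s, teacher_attn, eligible, k_pos=32):
-     """q_s, k_s: [n, d_search]; teacher_attn: [n, n] teacher probabilities
-     (already masked and renormalized); eligible: [n, n] bool causal mask."""
-     logits = (q_s @ k_s.T).masked_fill(~eligible, float("-inf"))
-     log_p = F.log_softmax(logits, dim=-1)
-     # KL(teacher || student) over the key distribution.
-     kl = F.kl_div(log_p, teacher_attn, reduction="batchmean")
-     # InfoNCE: the teacher's top-k_pos keys serve as positives for each query.
-     pos_idx = teacher_attn.topk(k_pos, dim=-1).indices
-     infonce = -log_p.gather(-1, pos_idx).mean()
-     return infonce + kl
- ```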
403
-
404
- At inference, we monkey-patch each trained layer's attention forward to:
405
-
406
- 1. Compute `q_search`, `k_search` from the same hidden state.
407
- 2. Build a per-batch FAISS HNSW index over `k_search` (default params).
408
- 3. Retrieve top-`K_retrieve` positions (causal-respecting) per query.
409
- 4. Run standard attention restricted to those `K_retrieve` keys.
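-
- A single-head, prefill-only sketch of steps 1-4 (the repo's `inference.py`
- additionally handles batching, grouped-query heads, segment masks, and filler
- ids; treat this as an outline, not the exact code):
-
- ```python
- import faiss
- import torch
-
- def ann_sparse_attention(h, W_Qs, W_Ks, q, k, v, K_retrieve=128):
-     """h: [seq, d_model]; q, k, v: [seq, d_head] native post-RoPE tensors."""
-     seq = h.shape[0]
-     q_s = (h @ W_Qs).float().cpu().numpy()              # [seq, d_search]
-     k_s = (h @ W_Ks).float().cpu().numpy()
-     index = faiss.IndexHNSWFlat(k_s.shape[1], 32, faiss.METRIC_INNER_PRODUCT)
-     index.add(k_s)                                       # CPU index per forward
-     _, idx = index.search(q_s, K_retrieve)               # [seq, K] candidates
-     idx = torch.from_numpy(idx)
-     # Mask non-causal candidates and FAISS filler ids (-1) out of the softmax.
-     valid = (idx >= 0) & (idx <= torch.arange(seq)[:, None])
-     idx = idx.clamp(min=0)
-     k_sel, v_sel = k[idx], v[idx]                        # [seq, K, d_head]
-     scores = (q[:, None] * k_sel).sum(-1) / (k.shape[-1] ** 0.5)
-     probs = torch.softmax(scores.masked_fill(~valid, float("-inf")), dim=-1)
-     return (probs[..., None] * v_sel).sum(dim=1)         # [seq, d_head]
- ```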
410
-
411
- The base model's parameters are never touched. The recommended d128 pilot
412
- trains 3.93M parameters total.
413
-
414
- ## Repo layout
415
-
416
- ```
417
- config.py Run config (pilot defaults; make_headline_config() for follow-up)
418
- model.py SearchProjection, FrozenForwardCapture (with QK reconstruction
419
- trick: capture (Q, K) post-RoPE while the forward stays in FA),
420
- contrastive + KL distillation losses
421
- data.py Long-context dataloader (packing off by default to avoid
422
- cross-segment attention leakage; pin_memory, prefetch)
423
- inference.py ANN-substituted attention (exact top-K for analysis;
424
- CPU-FAISS HNSW prototype path — not a deployable kernel)
425
- eval.py recall@K curve, mass@K curve, full-vs-ANN PPL,
426
- MoE router stability
427
- train.py Training loop, Liger setup, FA-3→FA-2→SDPA→eager fallback,
428
- base-model freeze + drift check, auto-resume from latest ckpt
429
- tests/ QK reconstruction verification + 50-step smoke test
430
- ```
431
-
432
- ## Quick start
433
-
434
- ```bash
435
- pip install -r requirements.txt
436
- export WANDB_API_KEY=<key>   # keep it in the environment; never check it in
437
- export HF_TOKEN=<token> # for faster Hub downloads
438
-
439
- # Pre-launch checks
440
- python -c "from transformers import AutoConfig; \
441
- print(AutoConfig.from_pretrained('Qwen/Qwen3-4B-Instruct-2507'))"
442
- python tests/test_qk_reconstruction.py
443
- python tests/smoke_test.py
444
-
445
- # Packed d_search ablation
446
- bash scripts/run_packed_ablation.sh
447
-
448
- # Default clean pilot (packing off; data-sparse on WikiText articles)
449
- python train.py --config pilot_d64_clean
450
- ```
451
-
452
- ## Configuration
453
-
454
- The default `Config` is the 1-day pilot:
455
-
456
- | Knob | Pilot | Headline |
457
- |---|---|---|
458
- | `seq_len` | 4096 | 8192 |
459
- | `batch_size` | 8 | 8 |
460
- | `total_steps` | 1000 | 6000 |
461
- | layers trained | 6 (`[4,8,12,16,20,24]`) | 34 (`range(36)` minus reserved `[0, 35]`) |
462
- | trainable params | 1.97M at d64; 3.93M at d128 | 11.1M at d64 |
463
- | `d_search` | 64 default; d128 recommended from ablation | 64 default |
464
- | `K_retrieve_eval` | 128 | 128 |
465
-
466
- Pilot is the proof-of-concept; headline trains every attention layer except
467
- the first (raw-embedding-adjacent) and last (output-logits-adjacent), which is
468
- the deployment-relevant claim that the technique scales to whole-model layer coverage.
469
-
470
- Use `make_pilot_d128_packed_config()` to reproduce the current recommended
471
- pilot, or `make_headline_config()` for the broader 34-layer run.
472
-
473
- ## Performance choices
474
-
475
- - `attn_implementation` resolves at load time as
476
- `flash_attention_3 → flash_attention_2 → sdpa → eager`. On B200 with no
477
- flash-attn package installed, SDPA wins — its built-in flash backend is
478
- ~80-90% of FA-2's throughput with zero build dependency.
479
- - Liger kernels applied via `apply_liger_kernel_to_qwen3` (RMSNorm, SwiGLU,
480
- RoPE fused — typically 30-50% faster forward).
481
- - The QK-reconstruction trick keeps SDPA/FA fast on the trained layers:
482
- we monkey-patch them to capture `(Q, K)` post-RoPE, then reconstruct
483
- `softmax(QK^T/√d)` ourselves *after* the forward returns. The forward
484
- never sets `output_attentions=True` (which would force eager).
485
- - `torch.compile(search_module, mode="max-autotune")` on the search
486
- projections; base model uncompiled (works but flaky for novel architectures).
487
- - bf16 throughout; loss math cast to fp32 for numerical stability of softmax.
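-
- The reconstruction step after the forward looks roughly like this (the capture
- itself recomputes Q/K from the layer's own projections plus RoPE; only the
- post-forward fp32 math is sketched, aggregated over heads for illustration):
-
- ```python
- import torch
-
- def reconstruct_teacher_attn(q, k, eligible):
-     """q, k: [n_heads, seq, d_head] post-RoPE; eligible: [seq, seq] bool mask."""
-     scores = (q.float() @ k.float().transpose(-1, -2)) / (q.shape[-1] ** 0.5)
-     scores = scores.masked_fill(~eligible, float("-inf"))
-     attn = torch.softmax(scores, dim=-1)                  # [n_heads, seq, seq]
-     return attn.mean(dim=0)                               # head-aggregated teacher
- ```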
488
-
489
- ## Verifying the QK reconstruction
490
-
491
- The post-RoPE Q/K capture must match what the model's eager attention computes
492
- or distillation supervision is wrong. The test asserts top-32 agreement
493
- above 99% per layer:
494
-
495
- ```bash
496
- python tests/test_qk_reconstruction.py --model Qwen/Qwen3-4B-Instruct-2507
497
- # layer 0: PASS max|Δ|=2.54e-02 top-32 agree=0.9963
498
- # layer 1: PASS max|Δ|=5.27e-02 top-32 agree=0.9941
499
- # ...
500
- # QK reconstruction verified.
501
- ```
502
-
503
- The bf16 max-abs differences (~0.05) are just numerical noise; the
504
- *ranking* of attention positions matches.
505
-
506
- ## Reproducing the pilot
507
-
508
- ```bash
509
- git clone git@github.com:unixsysdev/ann-sparseattention.git
510
- cd ann-sparseattention
511
- pip install -r requirements.txt
512
- python train.py --config pilot_d128_packed
513
- ```
514
-
515
- A single H100/H200/B200 is sufficient: roughly 8GB of GPU RAM for the 4B model
516
- plus ~10GB for activations at 4K context, batch 8.
517
-
518
- ## License
519
-
520
- MIT.
 
1
+ ---
2
+ license: mit
3
+ library_name: transformers
4
+ tags:
5
+ - sparse-attention
6
+ - approximate-nearest-neighbors
7
+ - faiss
8
+ - qwen3
9
+ - long-context
10
+ base_model: Qwen/Qwen3-4B-Instruct-2507
11
+ ---
12
+
13
  # ann-sparseattention
14
 
15
+ This repository mirrors checkpoints and result artifacts for
16
+ [`unixsysdev/ann-sparseattention`](https://github.com/unixsysdev/ann-sparseattention).
17
+
18
+ The method trains tiny per-layer query/key "search projections" on a frozen
19
+ LLM so that attention-relevant keys are nearest neighbors in a low-dimensional
20
+ search space. At inference, candidate selection can be done with standard ANN
21
+ machinery such as FAISS HNSW, then ordinary attention is computed over the
22
+ retrieved native KV vectors.
23
+
24
+ This is a research prototype, not a production sparse-attention runtime.
25
 
26
  ## Current status
27
 
28
+ Validated clean pilot:
29
+
30
+ - Base model: `Qwen/Qwen3-4B-Instruct-2507`
31
+ - Dataset: WikiText-103
32
+ - Context: 4096 tokens
33
+ - Clean masking: packed block-causal segment isolation
34
+ - Recommended clean checkpoint:
35
+ `checkpoints_block_d128/search_step_1000.pt`
36
+ - Trained layers: `[4, 8, 12, 16, 20, 24]`
37
+ - `d_search=128`
38
+ - Trainable parameters: 3.93M
39
+ - K=128: +0.07% PPL gap vs full attention
40
+ - K=256: +0.01% PPL gap vs full attention
41
+
42
+ Broad-layer experiments:
43
+
44
+ - `checkpoints_all36_d128_block/protected/search_step_500_keep.pt` is the best
45
+ all-36 checkpoint observed so far.
46
+ - All-36 step 500: recall@K=0.816, PPL gap +3.23%.
47
+ - All-36 step 750 regressed to +3.96% despite stable recall.
48
+ - Per-layer mass@K identified L00/L01/L02 as the weak early layers.
49
+ - A follow-up all32 run reserves full attention on `[0, 1, 2, 35]` and trains
50
+ layers `3..34`; checkpoints will be mirrored here as they become useful.
51
+
52
+ ## Important results
53
+
54
+ ### Clean six-layer block-causal result
55
+
56
+ | K | Recall@K | mass@K | sparse PPL | PPL gap |
57
+ |---:|---:|---:|---:|---:|
58
  | 128 | 0.744 | 0.787 | 30.47 | +0.07% |
59
  | 256 | 0.879 | 0.953 | 30.45 | +0.01% |
60
  | 512 | n/a | n/a | 30.45 | +0.01% |
61
 
62
+ The large negative PPL gaps from earlier packed-with-leakage experiments are
63
+ not used as clean claims. With block-causal masking, the robust claim is
64
+ full-attention parity on the six-layer pilot.
65
+
66
+ ### Quest-style page baseline
67
 
68
  | Method | K | Recall@K | mass@K | PPL | PPL gap |
69
  |---|---:|---:|---:|---:|---:|
 
72
  | learned search exact | 256 | 0.879 | 0.953 | 30.45 | +0.01% |
73
  | Quest-style page | 256 | 0.838 | 0.909 | 30.45 | +0.03% |
74
 
75
+ Paired 32-batch NLL evaluation:
76
 
77
+ | K | full PPL | learned PPL | Quest PPL | learned - Quest NLL delta, 95% CI | Read |
78
+ |---:|---:|---:|---:|---:|---|
79
  | 128 | 28.03 | 28.07 | 28.01 | +0.00205 `[+0.00160, +0.00251]` | Quest slightly better |
80
  | 256 | 28.03 | 28.04 | 28.04 | -0.00005 `[-0.00029, +0.00018]` | statistical tie |
81
 
82
+ The honest claim is retrieval-fidelity and ANN-compatibility, not a PPL win
83
+ over Quest.
 
 
84
 
85
+ ### FAISS/HNSW compatibility
86
 
87
+ The corrected clean FAISS path builds per-segment HNSW indexes when a
88
89
 
90
+ | Method | K | PPL | PPL gap |
91
+ |---|---:|---:|---:|
92
+ | learned exact | 128 | 30.47 | +0.07% |
93
+ | learned FAISS/HNSW | 128 | 30.47 | +0.09% |
94
+ | learned exact | 256 | 30.45 | +0.01% |
95
+ | learned FAISS/HNSW | 256 | 30.46 | +0.04% |
96
 
97
+ This validates that the learned search vectors are compatible with
98
+ off-the-shelf ANN. It is not a wall-clock result: the prototype uses CPU FAISS
99
+ and per-forward index construction.
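+
+ For reference, per-segment index construction looks roughly like this
+ (an illustrative sketch; the runnable version lives in the GitHub repo's
+ `inference.py`):
+
+ ```python
+ import faiss
+ import numpy as np
+
+ def build_segment_indexes(k_search: np.ndarray, segment_ids: np.ndarray, m: int = 32):
+     """k_search: [seq, d_search] float32; segment_ids: [seq] int."""
+     indexes = {}
+     for seg in np.unique(segment_ids):
+         pos = np.where(segment_ids == seg)[0]
+         index = faiss.IndexHNSWFlat(k_search.shape[1], m, faiss.METRIC_INNER_PRODUCT)
+         index.add(k_search[pos])
+         indexes[seg] = (index, pos)  # map local ANN ids back to global positions
+     return indexes
+ ```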
 
100
 
101
+ ### All-36 result so far
102
 
103
+ | Step | Recall@K (eval) | PPL gap |
104
+ |---:|---:|---:|
105
+ | 250 | 0.805 | +6.27% |
106
+ | 500 | 0.816 | +3.23% |
107
+ | 750 | 0.817 | +3.96% |
108
 
109
+ All-36 is feasible but not at parity under current hyperparameters. Step 500 is
110
+ kept because it is the best observed PPL checkpoint.
111
 
112
+ Per-layer step-500 mass@K at K=128:
113
 
114
+ | Layer | raw-QK | learned | delta |
115
+ |---:|---:|---:|---:|
116
+ | L00 | 0.922 | 0.780 | -0.142 |
117
+ | L01 | 0.918 | 0.851 | -0.067 |
118
+ | L02 | 0.939 | 0.899 | -0.040 |
119
+ | L03 | 0.939 | 0.924 | -0.015 |
120
+ | L04 | 0.944 | 0.933 | -0.011 |
121
+ | L05 | 0.964 | 0.947 | -0.017 |
122
+ | L06 | 0.956 | 0.936 | -0.020 |
123
+ | L07 | 0.982 | 0.982 | +0.000 |
124
+ | L08 | 0.971 | 0.970 | -0.001 |
125
+ | L09 | 0.959 | 0.976 | +0.017 |
126
+ | L20 | 0.959 | 0.975 | +0.016 |
127
+ | L21 | 0.966 | 0.979 | +0.014 |
128
+ | L34 | 0.976 | 0.960 | -0.016 |
129
+ | L35 | 0.980 | 0.967 | -0.013 |
130
+ | avg | 0.966 | 0.960 | -0.006 |
131
 
132
+ The next run reserves `[0, 1, 2, 35]` and trains layers `3..34`.
133
+
134
+ ## Checkpoints
135
+
136
+ Important checkpoint paths in this HF repo:
137
+
138
+ - `checkpoints_block_d128/search_step_1000.pt`: clean six-layer d128 parity checkpoint.
139
+ - `checkpoints_all36_d128_block/protected/search_step_500_keep.pt`: best observed all-36 checkpoint so far.
140
+ - `checkpoints_all36_d128_block/search_step_800.pt`: latest all-36 checkpoint before stopping for analysis.
141
+ - `checkpoints_all32_d128_block_reserve_0_1_2_35/`: active follow-up, uploaded as useful checkpoints are saved.
142
+
143
+ These checkpoints contain the trained search projection module and optimizer
144
+ state. They do not contain or modify the base Qwen model weights.
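+
+ A loading sketch (the exact key layout inside the checkpoint is defined by the
+ GitHub repo's `train.py`, so treat the printed keys as whatever that code saved):
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ path = hf_hub_download(
+     repo_id="datasysdev/ann-sparseattention",
+     filename="checkpoints_block_d128/search_step_1000.pt",
+ )
+ state = torch.load(path, map_location="cpu")
+ print(list(state))  # search-projection weights plus optimizer state
+ ```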
145
+
146
+ ## Limitations
147
+
148
+ - No production wall-clock speedup has been measured.
149
+ - No GPU-resident ANN or fused sparse attention kernel yet.
150
+ - No autoregressive KV-cache integration yet.
151
+ - Dynamic indexing is currently supported only by a retrieval-mass proxy.
152
+ - Main clean results are single-model and mostly single-seed.
153
+ - All-36 broad substitution is not full-attention parity yet.
154
+
155
+ Use the GitHub repository for runnable code, scripts, and the LaTeX paper draft.