datasysdev committed on
Commit e12b862 · verified · 1 Parent(s): 57c6b5b

Expand model card with clean results and artifact guide

Files changed (1): README.md +188 -29

README.md CHANGED
@@ -3,34 +3,94 @@ license: mit
  base_model: Qwen/Qwen3-4B-Instruct-2507
  tags:
  - sparse-attention
  - ann
  - qwen3
  - retrieval
  - research-artifact
  ---

  # ANN Sparse Attention Checkpoints

- Research artifact for distillation-trained ANN-friendly search projections for sparse attention.

- Base model: `Qwen/Qwen3-4B-Instruct-2507`.

- ## Current Clean Result

- The clean methodology is the packed block-causal d128 run in `checkpoints_block_d128/`.
- Packed examples are isolated with per-document `segment_ids`, reset `position_ids`, and a 4D block-causal attention mask. Retrieval, loss masking, mass@K, and recall@K use the same segment-causal eligibility mask.

- Clean block-causal d128 checkpoint:

- - `checkpoints_block_d128/search_step_1000.pt`
- - 6 trained layers: `[4, 8, 12, 16, 20, 24]`
- - `d_search=128`, 3.93M trainable parameters
- - K=128 exact learned search: PPL gap `+0.07%`, mass@K `0.787`, recall@K `0.744`
- - K=256 exact learned search: PPL gap `+0.01%`, mass@K `0.953`, recall@K `0.879`

- Interpretation: clean block-causal evaluation shows full-attention parity, not a clean denoising/improvement claim.

- ## Clean Per-layer Retrieval, K=128

  From `checkpoints_block_d128/search_step_1000.compare_retrieval.json`:

@@ -44,11 +104,31 @@ From `checkpoints_block_d128/search_step_1000.compare_retrieval.json`:
  | 24 | 0.978 | 0.984 |
  | avg | 0.969 | 0.973 |

- With segment isolation, early trained layers are not uniquely diffuse or hard; all six trained layers have high oracle mass and learned projections match/slightly exceed raw-QK retrieval.

  ## Quest-style Page Baseline

- From `checkpoints_block_d128/search_step_1000.quest_page16.json`, using page size 16, native post-RoPE Q/K min/max pages, and the same block-causal eligibility mask:

  | Method | K | Recall@K | mass@K | PPL | PPL gap |
  |---|---:|---:|---:|---:|---:|
@@ -61,27 +141,106 @@ Both methods are effectively full-attention parity on PPL. Learned projections r
  ## Packed Leakage-confounded Ablations

- The packed d64/d128/d256 runs are included for capacity-scaling history and should not be used for clean quality claims because packed examples could attend across document boundaries.

  Packed d_search ablation at K=128:

- | d_search | learned mass@K=128 | raw-QK oracle | learned/oracle | final PPL gap |
  |---|---:|---:|---:|---:|
- | 64 | 0.492 | 0.488 | 1.01x | +2.39% |
- | 128 | 0.503 | 0.488 | 1.03x | -1.81% |
- | 256 | 0.509 | 0.488 | 1.04x | -1.85% |

- The large negative packed K-sweep gaps were leakage-confounded and should be treated as historical/debugging evidence only, not as the headline.

- ## Folder Guide

- - `checkpoints_block_d128/`: clean block-causal d128 checkpoint and JSON eval artifacts. Use this for current claims.
- - `checkpoints_packed_d64/`, `checkpoints_packed_d128/`, `checkpoints_packed_d256/`: leakage-confounded packed ablation checkpoints.
- - `checkpoints_d64/`: earlier unpacked d64 checkpoints.
- - `checkpoints/`: original pilot checkpoint and compare JSON.

- ## Code

- Source repo: https://github.com/unixsysdev/ann-sparseattention

- The repo README contains the current methodology notes and reproduction commands.
  base_model: Qwen/Qwen3-4B-Instruct-2507
  tags:
  - sparse-attention
+ - approximate-nearest-neighbor
  - ann
  - qwen3
  - retrieval
+ - attention
  - research-artifact
+ library_name: pytorch
  ---

  # ANN Sparse Attention Checkpoints

+ This repository contains checkpoint artifacts for a research prototype that trains tiny per-layer search projections on a frozen LLM, so dense attention can be approximated by retrieving a small causal key set in a learned low-dimensional space.
+
+ The associated source repo is [unixsysdev/ann-sparseattention](https://github.com/unixsysdev/ann-sparseattention). The GitHub repo contains the training/eval code; this Hugging Face repo stores the checkpoints and JSON result artifacts.
+
+ ## Status
+
+ Research artifact, not a deployable inference package.
+
+ The clean result is narrow but real: on a 6-layer Qwen3-4B pilot with packed block-causal WikiText evaluation, the learned d128 search projections preserve full-attention perplexity under exact sparse substitution.
+
+ What survives clean methodology:
+
+ - Full-attention parity on the block-causal d128 pilot.
+ - Strong teacher-attention mass recovery with learned projections.
+ - Learned search projections recover more teacher attention mass than the Quest-style page heuristic at the same token budget on this slice.
+ - The earlier negative PPL gaps from packed-with-leakage runs do **not** survive as a clean denoising headline.
+
+ What is not established yet:
+
+ - Wall-clock speedup. The current runtime is a correctness prototype.
+ - Confidence intervals across seeds.
+ - LongBench/RULER/needle downstream task quality.
+ - Dynamic decode-mode index insertion.
+ - Whole-model / all-layer substitution.
+ - GPU-resident ANN or fused sparse-attention kernels.
+
+ ## Base Model
+
+ - Base model: `Qwen/Qwen3-4B-Instruct-2507`
+ - Layers trained in pilot: `[4, 8, 12, 16, 20, 24]`
+ - Clean recommended checkpoint: `checkpoints_block_d128/search_step_1000.pt`
+ - Search dimension: `d_search=128`
+ - Trainable parameters: 3.93M total, about 0.1% of the base model (see the sketch after this list)
+ - Base model weights are **not** included here. These checkpoints contain only the learned search projection module and training metadata.
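+
+ As a sanity check on the 3.93M figure, a back-of-envelope sketch assuming one linear query-side and one key-side projection per trained layer with no bias (the config's `use_mlp_proj` flag suggests an MLP variant also exists, so the exact module may differ):
+
+ ```python
+ # Hypothetical breakdown consistent with the reported trainable-parameter count.
+ d_model = 2560   # Qwen3-4B hidden size
+ d_search = 128   # learned search dimension
+ n_layers = 6     # trained layers [4, 8, 12, 16, 20, 24]
+ n_proj = 2       # assumed: one q-side and one k-side linear map per layer
+ print(n_layers * n_proj * d_model * d_search)  # 3932160 ≈ 3.93M
+ ```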
+
+ ## Folder Guide
+
+ Use `checkpoints_block_d128/` for current clean claims.
+
+ | Folder | Meaning | Use for claims? |
+ |---|---|---|
+ | `checkpoints_block_d128/` | Clean packed block-causal d128 run and eval artifacts | Yes |
+ | `checkpoints_packed_d64/` | Packed d64 leakage-confounded capacity run | Capacity history only |
+ | `checkpoints_packed_d128/` | Packed d128 leakage-confounded capacity run | Capacity history only |
+ | `checkpoints_packed_d256/` | Packed d256 leakage-confounded capacity run | Capacity history only |
+ | `checkpoints_d64/` | Earlier unpacked d64 checkpoints | Debug/history |
+ | `checkpoints/` | Original pilot checkpoint and compare JSON | Debug/history |
+
+ The clean block-causal run fixed the core packing issue by assigning each packed document a `segment_id`, resetting `position_ids`, and supplying a 4D block-causal attention mask so tokens can only attend causally within their own document.
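+
+ For illustration, a minimal sketch of how such a mask can be built, assuming `segment_ids` is a `[B, L]` integer tensor with one id per packed document (the repo's actual helper may differ):
+
+ ```python
+ import torch
+
+ def block_causal_mask(segment_ids: torch.Tensor) -> torch.Tensor:
+     """Boolean [B, 1, L, L] mask: query i may attend key j only if
+     j <= i (causal) and both tokens share a segment (same document)."""
+     same_doc = segment_ids.unsqueeze(-1) == segment_ids.unsqueeze(-2)  # [B, L, L]
+     L = segment_ids.shape[-1]
+     causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=segment_ids.device))
+     return (same_doc & causal).unsqueeze(1)  # [B, 1, L, L], broadcast over heads
+
+ # Two packed documents of lengths 3 and 2 in one row:
+ mask = block_causal_mask(torch.tensor([[0, 0, 0, 1, 1]]))
+ ```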
+
+ ## Clean Block-Causal Result
+
+ Command used for the clean d128 checkpoint:
+
+ ```bash
+ python train.py --config pilot_d128_block
+ python k_sweep.py \
+   --ckpt /tmp/checkpoints_block_d128/search_step_1000.pt \
+   --K 128,256,512 \
+   --no-use-faiss
+ ```
+
+ Evaluation slice: 16 packed block-causal WikiText batches at 4K context.
+
+ `PPL_full = 30.44`
+
+ | K | Recall@K | mass@K | PPL_ANN | PPL gap |
+ |---|---:|---:|---:|---:|
+ | 128 | 0.744 | 0.787 | 30.47 | +0.07% |
+ | 256 | 0.879 | 0.953 | 30.45 | +0.01% |
+ | 512 | n/a | n/a | 30.45 | +0.01% |
+
+ K=512 has no meaningful mass/recall average on this WikiText slice because almost no same-segment queries have 512 valid causal keys. The PPL value is still shown, but K=512 should not be used as a retrieval-quality point for this dataset slice.
+
+ Interpretation: the clean result supports **quality-preserving sparse substitution**, not a claim that sparse attention improves over full attention.
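+
+ For reference, a minimal per-query sketch of the two retrieval metrics as this README uses them; names and shapes are illustrative, not the repo's API:
+
+ ```python
+ import torch
+
+ def mass_and_recall_at_k(teacher_attn: torch.Tensor, retrieved: torch.Tensor, K: int):
+     """teacher_attn: [T] full-attention weights over this query's eligible keys.
+     retrieved: [K] key indices selected in the learned search space.
+     mass@K   = teacher attention probability covered by the retrieved keys.
+     recall@K = overlap with the teacher's own top-K keys."""
+     mass = teacher_attn[retrieved].sum().item()
+     oracle = teacher_attn.topk(K).indices
+     recall = len(set(retrieved.tolist()) & set(oracle.tolist())) / K
+     return mass, recall
+ ```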
+
+ ## Clean Per-layer Retrieval at K=128

  From `checkpoints_block_d128/search_step_1000.compare_retrieval.json`:

  | 24 | 0.978 | 0.984 |
  | avg | 0.969 | 0.973 |

+ This changes the interpretation from the earlier leakage-confounded pilot. With segment isolation, early trained layers are not diffuse or uniquely hard. All six trained layers have high raw-QK oracle mass, and learned projections match or slightly exceed raw-QK retrieval across the tested set.
+
+ The next deployment hypothesis is therefore: substitute all tested layers, then validate on a broader all-layer run.

  ## Quest-style Page Baseline

+ `quest_sweep.py` implements a Quest-style min/max page selector for comparison (scoring sketched after this list):
+
+ - Page size: 16
+ - Native post-RoPE Q/K min/max metadata
+ - Same block-causal token eligibility mask
+ - Same sparse-attention gather path
+
+ This is a correctness baseline, not an optimized Quest runtime.
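+
+ A minimal sketch of the min/max page score, assuming contiguous 16-token pages over post-RoPE keys; the function below is illustrative, not the repo's implementation:
+
+ ```python
+ import torch
+
+ def quest_page_scores(q: torch.Tensor, k: torch.Tensor, page_size: int = 16) -> torch.Tensor:
+     """q: [d] query, k: [T, d] keys. Each page keeps channelwise min/max of its keys;
+     score = sum_d max(q_d * kmin_d, q_d * kmax_d), an upper bound on q.k within the page."""
+     scores = []
+     for start in range(0, k.shape[0], page_size):
+         page = k[start:start + page_size]
+         k_min = page.min(dim=0).values
+         k_max = page.max(dim=0).values
+         scores.append(torch.maximum(q * k_min, q * k_max).sum())
+     return torch.stack(scores)  # select top pages, then attend to their tokens
+ ```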
+
+ Command:
+
+ ```bash
+ python quest_sweep.py \
+   --ckpt /tmp/checkpoints_block_d128/search_step_1000.pt \
+   --K 128,256,512 \
+   --page-size 16
+ ```
+
+ Same 16-batch clean block-causal eval slice:

  | Method | K | Recall@K | mass@K | PPL | PPL gap |
  |---|---:|---:|---:|---:|---:|
 
  ## Packed Leakage-confounded Ablations

+ The packed d64/d128/d256 runs are included because they are useful for understanding capacity scaling, but they should not be used for clean quality claims. Those runs allowed cross-document attention inside packed examples.

  Packed d_search ablation at K=128:

+ | d_search | Params | learned mass@K=128 | raw-QK oracle | learned/oracle | final PPL gap |
+ |---|---:|---:|---:|---:|---:|
+ | 64 | 1.97M | 0.492 | 0.488 | 1.01x | +2.39% |
+ | 128 | 3.93M | 0.503 | 0.488 | 1.03x | -1.81% |
+ | 256 | 7.86M | 0.509 | 0.488 | 1.04x | -1.85% |
+
+ The packed leakage-confounded K-sweep showed large negative PPL gaps:
+
+ | K | Recall@K | mass@K | PPL_ANN | PPL gap |
  |---|---:|---:|---:|---:|
+ | 128 | 0.166 | 0.256 | 203.63 | -9.36% |
+ | 256 | 0.233 | 0.318 | 207.06 | -7.83% |
+ | 512 | 0.339 | 0.409 | 211.93 | -5.66% |

+ A second leaked packed slice preserved the shape: K=128 `-8.78%`, K=256 `-7.59%`, K=512 `-6.21%`. These numbers are retained for transparency and debugging history. They should not be reported as the headline because the clean block-causal rerun shows parity, not denoising.
+
+ ## What the Checkpoints Contain
+
+ Each `.pt` file is a PyTorch checkpoint with the learned search projection module and config metadata. The base LLM is loaded separately from Hugging Face.
+
+ Example loading pattern from the source repo:
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from config import Config
+ from model import SearchProjectionModule
+
+ # Checkpoint stores the search-module weights plus the training config dict.
+ ckpt = torch.load("checkpoints_block_d128/search_step_1000.pt", map_location="cpu", weights_only=False)
+ ckpt_cfg = ckpt["config"]
+
+ cfg = Config()
+ for key, value in ckpt_cfg.items():
+     if hasattr(cfg, key):
+         setattr(cfg, key, value)
+
+ # Frozen base model; weights come from the Hub, not from this repo.
+ base = AutoModelForCausalLM.from_pretrained(
+     cfg.base_model_name,
+     dtype=torch.bfloat16,
+     device_map="auto",
+     attn_implementation="sdpa",
+ )
+
+ # Layers that receive learned search projections.
+ layers = [
+     i for i in cfg.full_attention_layer_indices
+     if i not in cfg.reserved_full_attention_indices
+ ]
+ search = SearchProjectionModule(
+     d_model=base.config.hidden_size,
+     d_search=cfg.d_search,
+     layer_indices=layers,
+     use_mlp=cfg.use_mlp_proj,
+ ).to(base.device).to(torch.bfloat16)
+ search.load_state_dict(ckpt["search_module"])
+ search.eval()
+ ```
+
+ See the GitHub repo for full eval scripts and monkey-patched sparse-attention wrappers.
+
+ ## Runtime Caveat
+
+ The current `inference.py` path is a correctness prototype:
+
+ - Exact top-K path materializes a dense `[B, L, L]` similarity and is for analysis only.
+ - FAISS/HNSW path builds a CPU index per forward pass and transfers data across CPU/GPU.
+ - Gathered sparse attention still uses dense-style tensor expansion internally.
+
+ Therefore, any FLOP/scoring reductions are algorithmic estimates, not measured wall-clock speedups. A deployable runtime needs GPU-resident retrieval and a fused sparse/paged attention kernel. The exact top-K path is sketched below.
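+
+ For illustration, a minimal sketch of why the exact path is analysis-only, assuming projected queries and keys `q_s`, `k_s` of shape `[L, d_search]` and a boolean block-causal eligibility mask; names are illustrative, not the repo's API:
+
+ ```python
+ import torch
+
+ def exact_topk_indices(q_s, k_s, eligible, K):
+     """Materializes the full [L, L] similarity in the learned search space,
+     masks non-causal / cross-segment keys, then takes exact top-K.
+     O(L^2) memory and compute: fine for evaluation, not for deployment."""
+     sim = q_s @ k_s.transpose(-1, -2)                # [L, L] dense scores
+     sim = sim.masked_fill(~eligible, float("-inf"))
+     return sim.topk(min(K, sim.shape[-1]), dim=-1).indices
+ ```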
+
+ ## Recommended Use
+
+ Use this repo for:
+
+ - Reproducing the clean d128 block-causal result.
+ - Inspecting search projection checkpoints.
+ - Comparing learned search retrieval against raw-QK and Quest-style page retrieval.
+ - Building follow-up experiments such as dynamic-index insertion or all-layer substitution.
+
+ Do not use this repo as:
+
+ - A drop-in accelerated inference engine.
+ - Evidence that sparse attention beats full attention on clean methodology.
+ - A complete comparison against all sparse-attention baselines.
+
+ ## Next Experiments
+
+ The most important follow-ups are:
+
+ 1. Dynamic-index demonstration during long generation.
+ 2. Multi-seed confidence intervals for block-causal d128.
+ 3. LongBench/RULER/needle task evaluation.
+ 4. All-layer substitution run.
+ 5. GPU-resident retrieval and decode-mode KV-cache integration.
+
+ ## Citation / Attribution
+
+ This is an in-progress research artifact. If you use it, cite the GitHub repo and this Hugging Face checkpoint repository.
+
+ Source: https://github.com/unixsysdev/ann-sparseattention