---
license: mit
library_name: transformers
tags:
  - sparse-attention
  - approximate-nearest-neighbors
  - faiss
  - qwen3
  - long-context
base_model: Qwen/Qwen3-4B-Instruct-2507
---

# ann-sparseattention

This repository mirrors checkpoints and result artifacts for
[`unixsysdev/ann-sparseattention`](https://github.com/unixsysdev/ann-sparseattention).

The method trains tiny per-layer query/key "search projections" on a frozen
LLM so that attention-relevant keys are nearest neighbors in a low-dimensional
search space. At inference, candidate selection can be done with standard ANN
machinery such as FAISS HNSW, then ordinary attention is computed over the
retrieved native KV vectors.

This is a research prototype, not a production sparse-attention runtime.
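As a sketch, the inference path looks like the following; shapes and names are illustrative, and exact top-K search stands in here for the FAISS HNSW step (this is not the repository's actual API):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_search, n_keys, topk = 64, 8, 512, 32

# Frozen base-model vectors for one head at one layer (random stand-ins).
q = rng.standard_normal(d_model)
K = rng.standard_normal((n_keys, d_model))
V = rng.standard_normal((n_keys, d_model))

# Tiny trained "search projections" (randomly initialized here).
Wq = rng.standard_normal((d_model, d_search)) / np.sqrt(d_model)
Wk = rng.standard_normal((d_model, d_search)) / np.sqrt(d_model)

# 1) Select candidates by nearest neighbors in the low-dim search space.
q_s, K_s = q @ Wq, K @ Wk
cand = np.argsort(K_s @ q_s)[-topk:]  # exact top-K; FAISS HNSW in practice

# 2) Ordinary softmax attention over the retrieved native KV vectors.
scores = K[cand] @ q / np.sqrt(d_model)
probs = np.exp(scores - scores.max())
probs /= probs.sum()
out = probs @ V[cand]
```

The training objective's job is to make step 1's inner products rank keys the way the frozen model's full attention scores would.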

## Current status

Validated clean pilot:

- Base model: `Qwen/Qwen3-4B-Instruct-2507`
- Dataset: WikiText-103
- Context: 4096 tokens
- Clean masking: packed block-causal segment isolation
- Recommended clean checkpoint:
  `checkpoints_block_d128/search_step_1000.pt`
- Trained layers: `[4, 8, 12, 16, 20, 24]`
- `d_search=128`
- Trainable parameters: 3.93M
- K=128: +0.07% PPL gap vs full attention
- K=256: +0.01% PPL gap vs full attention

Broad-layer experiments:

- `checkpoints_all36_d128_block/protected/search_step_500_keep.pt` is the best
  all-36 checkpoint observed so far.
- All-36 step 500: recall@K=0.816, PPL gap +3.23%.
- All-36 step 750 regressed to +3.96% despite stable recall.
- Per-layer mass@K identified L00/L01/L02 as the weak early layers.
- The all32 reserved-edge run reserves full attention on `[0, 1, 2, 35]` and
  trains layers `3..34`. Final step 1000: recall@K=0.825, +1.746% PPL gap in
  training eval, and 20.97M trained search-projection parameters. The exact
  K-sweep gives +0.590% PPL gap at K=128 and -0.062% at K=256 on a small clean
  block-causal slice.

## Important results

### Clean six-layer block-causal result

| K | Recall@K | mass@K | sparse PPL | PPL gap |
|---:|---:|---:|---:|---:|
| 128 | 0.744 | 0.787 | 30.47 | +0.07% |
| 256 | 0.879 | 0.953 | 30.45 | +0.01% |
| 512 | n/a | n/a | 30.45 | +0.01% |

The large negative PPL gaps from earlier packed-with-leakage experiments are
not used as clean claims. With block-causal masking, the robust claim is
full-attention parity on the six-layer pilot.
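Here "block-causal" means a query attends only to earlier positions within its own packed segment. A minimal mask construction under that reading (the segment-id representation is an assumption about the packing format):

```python
import numpy as np

def block_causal_mask(segment_ids):
    """True where query position i may attend key position j:
    same packed segment, and j <= i (causal within the segment)."""
    seg = np.asarray(segment_ids)
    pos = np.arange(len(seg))
    same_segment = seg[:, None] == seg[None, :]
    causal = pos[:, None] >= pos[None, :]
    return same_segment & causal

# Two documents packed into one six-token sequence.
mask = block_causal_mask([0, 0, 0, 1, 1, 1])
```

Position 2 can see position 0 (same segment, causal), while position 3 cannot see position 2 (segment boundary), which is exactly the leakage the earlier packed experiments lacked.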

### Quest-style page baseline

| Method | K | Recall@K | mass@K | PPL | PPL gap |
|---|---:|---:|---:|---:|---:|
| learned search exact | 128 | 0.744 | 0.787 | 30.47 | +0.07% |
| Quest-style page | 128 | 0.669 | 0.727 | 30.41 | -0.11% |
| learned search exact | 256 | 0.879 | 0.953 | 30.45 | +0.01% |
| Quest-style page | 256 | 0.838 | 0.909 | 30.45 | +0.03% |
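The Quest-style baseline scores fixed-size key pages with per-dimension min/max bounds rather than a learned projection. A sketch of that heuristic (page size and shapes here are illustrative):

```python
import numpy as np

def quest_page_scores(q, K, page_size):
    """Upper-bound q.k per page via per-dimension min/max key bounds,
    in the style of the Quest page heuristic."""
    n_pages = len(K) // page_size
    pages = K[: n_pages * page_size].reshape(n_pages, page_size, -1)
    kmin, kmax = pages.min(axis=1), pages.max(axis=1)
    # Per dimension, the larger of q*min and q*max bounds the dot product.
    return np.maximum(q * kmin, q * kmax).sum(axis=-1)

rng = np.random.default_rng(1)
q, K = rng.standard_normal(16), rng.standard_normal((64, 16))
scores = quest_page_scores(q, K, page_size=16)
best_pages = np.argsort(scores)[-2:]  # attend to all keys in top pages
```

Selection then keeps whole pages, which is why page recall@K trails the learned per-key retrieval in the table above while PPL stays competitive.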

Paired 32-batch NLL evaluation:

| K | full PPL | learned PPL | Quest PPL | learned - Quest NLL delta, 95% CI | Read |
|---:|---:|---:|---:|---:|---|
| 128 | 28.03 | 28.07 | 28.01 | +0.00205 `[+0.00160, +0.00251]` | Quest slightly better |
| 256 | 28.03 | 28.04 | 28.04 | -0.00005 `[-0.00029, +0.00018]` | statistical tie |

The honest claim is retrieval fidelity and ANN compatibility, not a PPL win
over Quest.
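The paired comparison averages per-batch NLL differences and reports a 95% CI. A minimal version of that arithmetic, assuming a normal-approximation CI over paired per-batch deltas (the exact procedure in the repo may differ), with illustrative numbers rather than the paper's data:

```python
import numpy as np

def paired_nll_delta_ci(nll_a, nll_b, z=1.96):
    """Mean(a - b) over paired batches with a normal-approximation 95% CI."""
    d = np.asarray(nll_a) - np.asarray(nll_b)
    mean = d.mean()
    half = z * d.std(ddof=1) / np.sqrt(len(d))
    return mean, (mean - half, mean + half)

# 32 paired per-batch NLLs (synthetic; learned ~0.002 nats worse).
rng = np.random.default_rng(2)
learned = 3.33 + 0.01 * rng.standard_normal(32)
quest = learned - 0.002 + 0.001 * rng.standard_normal(32)
delta, (lo, hi) = paired_nll_delta_ci(learned, quest)
```

Pairing on the same batches removes shared batch-level variance, which is what makes a delta as small as +0.002 nats resolvable with only 32 batches.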

### FAISS/HNSW compatibility

The corrected clean FAISS path builds per-segment HNSW indexes when a
block-causal mask is present.

| Method | K | PPL | PPL gap |
|---|---:|---:|---:|
| learned exact | 128 | 30.47 | +0.07% |
| learned FAISS/HNSW | 128 | 30.47 | +0.09% |
| learned exact | 256 | 30.45 | +0.01% |
| learned FAISS/HNSW | 256 | 30.46 | +0.04% |

This validates that the learned search vectors are compatible with
off-the-shelf ANN. It is not a wall-clock result: the prototype uses CPU FAISS
and per-forward index construction.

### All-36 and all32 broad-layer results

All-36 training trajectory:

| Step | Recall@K eval | PPL gap |
|---:|---:|---:|
| 250 | 0.805 | +6.27% |
| 500 | 0.816 | +3.23% |
| 750 | 0.817 | +3.96% |

All-36 is feasible but not parity under current hyperparameters. Step 500 is
kept because it is the best observed PPL checkpoint.

Per-layer step-500 mass@K at K=128 (selected layers shown):

| Layer | raw-QK | learned | delta |
|---:|---:|---:|---:|
| L00 | 0.922 | 0.780 | -0.142 |
| L01 | 0.918 | 0.851 | -0.067 |
| L02 | 0.939 | 0.899 | -0.040 |
| L03 | 0.939 | 0.924 | -0.015 |
| L04 | 0.944 | 0.933 | -0.011 |
| L05 | 0.964 | 0.947 | -0.017 |
| L06 | 0.956 | 0.936 | -0.020 |
| L07 | 0.982 | 0.982 | +0.000 |
| L08 | 0.971 | 0.970 | -0.001 |
| L09 | 0.959 | 0.976 | +0.017 |
| L20 | 0.959 | 0.975 | +0.016 |
| L21 | 0.966 | 0.979 | +0.014 |
| L34 | 0.976 | 0.960 | -0.016 |
| L35 | 0.980 | 0.967 | -0.013 |
| avg | 0.966 | 0.960 | -0.006 |

This diagnostic motivated reserving `[0, 1, 2, 35]` as full-attention layers and
training only layers `3..34`.
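For reference, the two diagnostics used throughout can be computed per query as follows; recall@K is the overlap between the learned top-K and the true attention top-K, and mass@K is the attention probability mass the learned top-K captures (single-query sketch, not the repo's batched implementation):

```python
import numpy as np

def recall_and_mass_at_k(attn_scores, learned_scores, k):
    """recall@K: overlap of learned top-K with the true attention top-K.
    mass@K: full-attention probability mass covered by the learned top-K."""
    true_top = set(np.argsort(attn_scores)[-k:])
    learned_top = np.argsort(learned_scores)[-k:]
    recall = len(true_top.intersection(learned_top)) / k
    probs = np.exp(attn_scores - attn_scores.max())
    probs /= probs.sum()
    mass = probs[learned_top].sum()
    return recall, mass

rng = np.random.default_rng(4)
attn = rng.standard_normal(256)          # true attention logits
learned = attn + 0.5 * rng.standard_normal(256)  # noisy search-space scores
recall, mass = recall_and_mass_at_k(attn, learned, k=32)
```

Passing the true scores as `learned_scores` recovers the "raw-QK" mass column: the best mass any top-K selector could achieve, which is why it upper-bounds the learned column almost everywhere in the table above.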

Final all32 reserved-edge training trajectory:

| Step | Recall@K eval | PPL gap | Read |
|---:|---:|---:|---|
| 250 | 0.812 | +2.283% | already better than all36 best training eval |
| 500 | 0.823 | +1.753% | converged to final quality band |
| 750 | 0.825 | +1.943% | small eval fluctuation |
| 1000 | 0.825 | +1.746% | final checkpoint; essentially tied with step 500 |

The all32 checkpoint is the current broad-substitution result. It is not
full-attention parity at K=128 in training eval, but it reduces the all36
quality cost while still substituting 32 of 36 layers. Post-hoc
`compare_retrieval` on step 1000 shows learned retrieval matches raw-QK mass on
the substituted layers: at K=128, learned mass is 0.971 vs raw-QK 0.969; at
K=256, learned mass is 0.993 vs raw-QK 0.994.

Exact K-sweep on the final all32 checkpoint, 2-batch clean block-causal slice
(`PPL_full = 20.5349`):

| K | mass@K | Recall@K | sparse PPL | PPL gap |
|---:|---:|---:|---:|---:|
| 16 | 0.546 | 0.518 | 24.86 | +21.064% |
| 32 | 0.627 | 0.572 | 21.85 | +6.422% |
| 64 | 0.722 | 0.652 | 20.94 | +1.974% |
| 128 | 0.807 | 0.746 | 20.66 | +0.590% |
| 256 | 0.902 | 0.876 | 20.52 | -0.062% |

K=512 is intentionally omitted from this table. The current script produced a
valid sparse-attention PPL line for K=512 but zero mass/recall, which is an
edge-case bug in the metric path when K exceeds the number of valid causal keys
for most same-segment queries. It should be rerun after fixing the metric
handling; the publishable sweep for now is K <= 256.
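The PPL gaps in these tables are relative percentages against the full-attention baseline, e.g. the K=16 row:

```python
def ppl_gap_pct(ppl_sparse, ppl_full):
    """Relative PPL gap in percent against the full-attention baseline."""
    return 100.0 * (ppl_sparse / ppl_full - 1.0)

# Reproduces the K=16 row of the all32 exact sweep (PPL_full = 20.5349).
gap = ppl_gap_pct(24.86, 20.5349)  # ~ +21.06%
```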

Coverage now looks like a real deployment knob:

| Configuration | Layers substituted | Coverage | PPL gap | Read |
|---|---:|---:|---:|---|
| Clean six-layer pilot | 6/36 | 17% | +0.07% at K=128 | quality-preserving pilot |
| all32 reserved-edge | 32/36 | 89% | +1.746% train eval; +0.590% exact sweep | near-parity broad substitution |
| all36 | 36/36 | 100% | +3.23% best observed | full substitution costs quality |

This is not yet enough to claim an optimal coverage ratio, but it suggests the
best deployment point is intermediate rather than "sparsify everything." A
12/18/20-layer coverage sweep is the next clean experiment.

## Positioning against related methods

The paper frames this method as closest in asymptotic shape to Reformer and
closest in practical baseline behavior to Quest.

| Method | Selection mechanism | Query-aware | Trained | Asymptotic | Exact softmax |
|---|---|---|---|---|---|
| Full attention | all keys | n/a | n/a | O(N²) | yes |
| Reformer | LSH hashing | yes | no | O(N log N) | over bucket |
| Performer | random features | n/a | no | O(N) | no |
| BigBird | window + random + global | mostly no | no | O(N) | over pattern |
| Longformer | sliding window + global | mostly no | no | O(N) | over pattern |
| NSA-style methods | block compression/selection | partial | partial | O(N²) proxy | yes |
| Quest | min/max page heuristic | yes | no | O(N) | over pages |
| This work | trained low-dim retrieval | yes | yes | O(N log N) | over retrieved set |

This is a design-positioning table, not a claim of completed production
superiority. The clean result proves the approach for the six-layer pilot, and
the all32 reserved-edge run shows broad substitution can get close to parity
when weak edge layers remain full attention. All36 full substitution is still
not parity.

This method targets a different deployment scenario than native
sliding-window/state-space/hybrid architectures such as Mistral-style sliding
window, Mamba, or Qwen3.6 Gated DeltaNet hybrids. Those models are trained from
scratch with their sparse or hybrid mechanism in place. This work is post-hoc:
train a base model with full attention for maximum expressivity, then add
lightweight retrieval projections afterward to make inference sub-linear without
changing base weights.

## Checkpoints

Important checkpoint paths in this HF repo:

- `checkpoints_block_d128/search_step_1000.pt`: clean six-layer d128 parity checkpoint.
- `checkpoints_all36_d128_block/protected/search_step_500_keep.pt`: best observed all-36 checkpoint so far.
- `checkpoints_all36_d128_block/search_step_800.pt`: latest all-36 checkpoint before stopping for analysis.
- `checkpoints_all32_d128_block_reserve_0_1_2_35/search_step_1000.pt`: final all32 reserved-edge checkpoint.
- `checkpoints_all32_d128_block_reserve_0_1_2_35/search_step_1000.compare_retrieval.json`: all32 per-layer retrieval comparison.
- `checkpoints_all32_d128_block_reserve_0_1_2_35/search_step_1000.k_sweep_exact.json`: all32 exact K-sweep.

These checkpoints contain the trained search projection module and optimizer
state. They do not contain or modify the base Qwen model weights.
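The checkpoints are plain `torch.load`-able dictionaries. A round-trip sketch of that layout; the key names and tensor shape below are assumptions for illustration, so inspect the real file's keys before relying on them:

```python
import io
import torch

# Hypothetical layout: search-projection weights plus optimizer state,
# and no base Qwen weights. Key names and shapes are assumptions.
state = {"search_proj": {"layers.4.q_proj.weight": torch.randn(128, 2560)},
         "optimizer": {}}
buf = io.BytesIO()          # stands in for a .pt file on disk
torch.save(state, buf)
buf.seek(0)
ckpt = torch.load(buf, map_location="cpu")
```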

## Limitations

- No production wall-clock speedup has been measured.
- No GPU-resident ANN or fused sparse attention kernel yet.
- No autoregressive KV-cache integration yet.
- Dynamic indexing is currently validated only via a retrieval-mass proxy.
- Main clean results are single-model and mostly single-seed.
- All-36 broad substitution is not full-attention parity.
- The all32 result is near parity on a small slice but still needs larger eval
  slices, task benchmarks, and a coverage Pareto sweep.

Use the GitHub repository for runnable code, scripts, and the LaTeX paper draft.