---
license: mit
library_name: transformers
tags:
- sparse-attention
- approximate-nearest-neighbors
- faiss
- qwen3
- long-context
base_model: Qwen/Qwen3-4B-Instruct-2507
---

# ann-sparseattention

This repository mirrors checkpoints and result artifacts for
[`unixsysdev/ann-sparseattention`](https://github.com/unixsysdev/ann-sparseattention).

The method trains tiny per-layer query/key "search projections" on a frozen
LLM so that attention-relevant keys are nearest neighbors in a low-dimensional
search space. At inference, candidate selection can be done with standard ANN
machinery such as FAISS HNSW, then ordinary attention is computed over the
retrieved native KV vectors.
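
As a rough illustration of the mechanism (not the repository's actual API; the
function name, shapes, and masking details below are assumptions), a
single-head PyTorch sketch projects queries and keys into the search space,
picks top-K candidates there, and then runs ordinary softmax attention over
only the retrieved native keys and values.

```python
import torch
import torch.nn.functional as F

def sparse_attention_via_search(q, k, v, w_q_search, w_k_search, top_k=128):
    """Illustrative single-head sketch.

    q, k, v:                [T, d_head] native queries/keys/values
    w_q_search, w_k_search: [d_head, d_search] trained search projections
    """
    T, d_head = q.shape
    qs = q @ w_q_search                      # [T, d_search] search-space queries
    ks = k @ w_k_search                      # [T, d_search] search-space keys

    # Exact top-K in the search space (an ANN index such as FAISS HNSW would
    # replace this step at long context).
    search_scores = qs @ ks.T                # [T, T]
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
    search_scores = search_scores.masked_fill(~causal, float("-inf"))
    k_eff = min(top_k, T)
    idx = search_scores.topk(k_eff, dim=-1).indices   # [T, k_eff] retrieved key ids

    # Ordinary attention over the retrieved native keys/values only; the extra
    # mask drops padded picks for queries with fewer than k_eff causal keys.
    k_sel, v_sel = k[idx], v[idx]                      # [T, k_eff, d_head]
    att = torch.einsum("td,tkd->tk", q, k_sel) / d_head ** 0.5
    att = att.masked_fill(idx > torch.arange(T).unsqueeze(-1), float("-inf"))
    return torch.einsum("tk,tkd->td", F.softmax(att, dim=-1), v_sel)
```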

This is a research prototype, not a production sparse-attention runtime.

## Extremely preliminary llama.cpp runtime GGUFs

Merged ANN GGUFs from the llama.cpp runtime branch are available under
[`gguf/`](https://huggingface.co/datasysdev/ann-sparseattention/tree/main/gguf):

| variant | hosted file | intended use |
|---|---|---|
| 6-layer pilot | [`Qwen3-4B-Instruct-2507-F16-ann-6layer-k128-v2.gguf`](https://huggingface.co/datasysdev/ann-sparseattention/blob/main/gguf/Qwen3-4B-Instruct-2507-F16-ann-6layer-k128-v2.gguf) | best strict HNSW password-recall behavior so far |
| all32 reserved-edge | [`Qwen3-4B-Instruct-2507-F16-ann-all32-k128-v2.gguf`](https://huggingface.co/datasysdev/ann-sparseattention/blob/main/gguf/Qwen3-4B-Instruct-2507-F16-ann-all32-k128-v2.gguf) | strongest broad-layer exact-runtime speed/quality candidate |
| all36 full substitution | [`Qwen3-4B-Instruct-2507-F16-ann-all36-k128-v2.gguf`](https://huggingface.co/datasysdev/ann-sparseattention/blob/main/gguf/Qwen3-4B-Instruct-2507-F16-ann-all36-k128-v2.gguf) | full-substitution stress test; visibly weaker quality |
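
To pull one of these files programmatically, a standard `huggingface_hub`
download is enough (the repo id and file path below are the ones linked above;
wiring the file into a llama.cpp build from the runtime branch follows that
branch's own instructions):

```python
from huggingface_hub import hf_hub_download

# Fetch the 6-layer pilot GGUF from this repo into the local HF cache.
gguf_path = hf_hub_download(
    repo_id="datasysdev/ann-sparseattention",
    filename="gguf/Qwen3-4B-Instruct-2507-F16-ann-6layer-k128-v2.gguf",
)
print(gguf_path)
```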

The corresponding runtime branch is
[`feat/llama-ann-runtime`](https://github.com/unixsysdev/ann-sparseattention/tree/feat/llama-ann-runtime).
That branch intentionally contains only the runtime snapshot and its logs.

These GGUFs are extremely preliminary. The current llama.cpp integration is
useful for reproducing CPU/ROCm correctness and early speed tests, but it is
not a production sparse-attention runtime. The HNSW path performs real
approximate ANN candidate selection, but currently rebuilds indices during
decode; it validates behavior rather than final latency. The exact top-K path
is the current correctness oracle and showed an early 33.6k-context speed
signal for all32, but prompt quality tests remain synthetic and incomplete.

## Current status

Validated clean pilot:

- Base model: `Qwen/Qwen3-4B-Instruct-2507`
- Dataset: WikiText-103
- Context: 4096 tokens
- Clean masking: packed block-causal segment isolation
- Recommended clean checkpoint: `checkpoints_block_d128/search_step_1000.pt`
- Trained layers: `[4, 8, 12, 16, 20, 24]`
- `d_search=128`
- Trainable parameters: 3.93M
- K=128: +0.07% PPL gap vs full attention
- K=256: +0.01% PPL gap vs full attention

Broad-layer experiments:

- `checkpoints_all36_d128_block/protected/search_step_500_keep.pt` is the best
  all-36 checkpoint observed so far.
- All-36 step 500: recall@K=0.816, PPL gap +3.23%.
- All-36 step 750 regressed to +3.96% despite stable recall.
- Per-layer mass@K identified L00/L01/L02 as the weak early layers.
- The all32 reserved-edge run reserves full attention on `[0, 1, 2, 35]` and
  trains layers `3..34`. Final step 1000: recall@K=0.825, +1.746% PPL gap in
  training eval, and 20.97M trained search-projection parameters. The exact
  K-sweep gives +0.590% PPL gap at K=128 and -0.062% at K=256 on a small clean
  block-causal slice.

## Important results

### Clean six-layer block-causal result

| K | Recall@K | mass@K | sparse PPL | PPL gap |
|---:|---:|---:|---:|---:|
| 128 | 0.744 | 0.787 | 30.47 | +0.07% |
| 256 | 0.879 | 0.953 | 30.45 | +0.01% |
| 512 | n/a | n/a | 30.45 | +0.01% |

The large negative PPL gaps from earlier packed-with-leakage experiments are
not used as clean claims. With block-causal masking, the robust claim is
full-attention parity on the six-layer pilot.
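
Recall@K and mass@K are used throughout these tables. One common way to define
them (the repository's exact metric code may differ in detail) is sketched
below: recall@K is the overlap between the keys selected in the search space
and the top-K keys under the true attention scores, and mass@K is the fraction
of the true attention probability captured by the retrieved set.

```python
import torch

def recall_and_mass_at_k(true_scores, search_scores, k):
    """true_scores, search_scores: [T_q, T_k] masked attention logits (one head).
    Returns (recall@K, mass@K) averaged over queries; illustrative definition."""
    probs = torch.softmax(true_scores, dim=-1)          # true attention distribution
    true_topk = true_scores.topk(k, dim=-1).indices     # "relevant" keys per query
    retrieved = search_scores.topk(k, dim=-1).indices   # keys the search space selects

    recalls, masses = [], []
    for t in range(true_scores.shape[0]):
        overlap = torch.isin(retrieved[t], true_topk[t]).float().mean()
        recalls.append(overlap)                          # recall@K for this query
        masses.append(probs[t, retrieved[t]].sum())      # retrieved attention mass
    return torch.stack(recalls).mean().item(), torch.stack(masses).mean().item()
```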

### Quest-style page baseline

| Method | K | Recall@K | mass@K | PPL | PPL gap |
|---|---:|---:|---:|---:|---:|
| learned search exact | 128 | 0.744 | 0.787 | 30.47 | +0.07% |
| Quest-style page | 128 | 0.669 | 0.727 | 30.41 | -0.11% |
| learned search exact | 256 | 0.879 | 0.953 | 30.45 | +0.01% |
| Quest-style page | 256 | 0.838 | 0.909 | 30.45 | +0.03% |
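
For context, the Quest-style baseline scores fixed-size pages of keys with a
query-aware upper bound built from per-dimension min/max key statistics, then
attends exactly over the keys of the best pages. A rough sketch of that page
score (the page size and layout here are assumptions, not the baseline script's
exact settings):

```python
import torch

def quest_page_scores(q, k, page_size=16):
    """q: [d] single query; k: [T, d] causal keys. Returns one score per page."""
    n_pages = (k.shape[0] + page_size - 1) // page_size
    scores = torch.empty(n_pages)
    for p in range(n_pages):
        page = k[p * page_size:(p + 1) * page_size]
        lo, hi = page.min(dim=0).values, page.max(dim=0).values
        # Upper bound on q . key for any key inside this page.
        scores[p] = torch.maximum(q * lo, q * hi).sum()
    return scores  # select the top-scoring pages, then run exact attention on them
```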

Paired 32-batch NLL evaluation:

| K | full PPL | learned PPL | Quest PPL | learned - Quest NLL delta (95% CI) | Read |
|---:|---:|---:|---:|---:|---|
| 128 | 28.03 | 28.07 | 28.01 | +0.00205 `[+0.00160, +0.00251]` | Quest slightly better |
| 256 | 28.03 | 28.04 | 28.04 | -0.00005 `[-0.00029, +0.00018]` | statistical tie |

The honest claim is retrieval-fidelity and ANN-compatibility, not a PPL win
over Quest.
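
The confidence intervals above come from a paired per-batch comparison. Under a
simple normal approximation over the 32 paired batch means (an assumption about
the evaluation setup, not a description of the repository's script), the delta
and its 95% CI can be computed as:

```python
import math

def paired_nll_delta_ci(nll_learned, nll_quest):
    """nll_learned, nll_quest: per-batch mean NLLs from the same paired batches."""
    deltas = [a - b for a, b in zip(nll_learned, nll_quest)]
    n = len(deltas)
    mean = sum(deltas) / n
    var = sum((d - mean) ** 2 for d in deltas) / (n - 1)
    half = 1.96 * math.sqrt(var / n)        # normal-approximation 95% half-width
    return mean, (mean - half, mean + half)
```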

### FAISS/HNSW compatibility

The corrected clean FAISS path builds per-segment HNSW indexes when a
block-causal mask is present.

| Method | K | PPL | PPL gap |
|---|---:|---:|---:|
| learned exact | 128 | 30.47 | +0.07% |
| learned FAISS/HNSW | 128 | 30.47 | +0.09% |
| learned exact | 256 | 30.45 | +0.01% |
| learned FAISS/HNSW | 256 | 30.46 | +0.04% |

This validates that the learned search vectors are compatible with
off-the-shelf ANN. It is not a wall-clock result: the prototype uses CPU FAISS
and per-forward index construction.
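
A minimal version of the per-segment HNSW step, assuming inner-product
similarity in the search space (the index parameters below are illustrative,
not the prototype's exact settings):

```python
import faiss
import numpy as np

def retrieve_with_hnsw(search_keys, search_queries, k=128, hnsw_m=32):
    """search_keys: [n_keys, d_search] for one segment; search_queries: [n_q, d_search].
    Builds a per-segment HNSW index and returns top-k candidate key ids per query."""
    d = search_keys.shape[1]
    index = faiss.IndexHNSWFlat(d, hnsw_m, faiss.METRIC_INNER_PRODUCT)
    index.add(np.ascontiguousarray(search_keys, dtype=np.float32))
    _, ids = index.search(np.ascontiguousarray(search_queries, dtype=np.float32), k)
    return ids  # candidate key ids; attention is then computed over native K/V
```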

### All-36 and all32 broad-layer results

All-36 training trajectory:

| Step | Recall@K eval | PPL gap |
|---:|---:|---:|
| 250 | 0.805 | +6.27% |
| 500 | 0.816 | +3.23% |
| 750 | 0.817 | +3.96% |

All-36 is feasible but not parity under current hyperparameters. Step 500 is
kept because it is the best observed PPL checkpoint.

Per-layer step-500 mass@K at K=128 (selected layers shown):

| Layer | raw-QK | learned | delta |
|---:|---:|---:|---:|
| L00 | 0.922 | 0.780 | -0.142 |
| L01 | 0.918 | 0.851 | -0.067 |
| L02 | 0.939 | 0.899 | -0.040 |
| L03 | 0.939 | 0.924 | -0.015 |
| L04 | 0.944 | 0.933 | -0.011 |
| L05 | 0.964 | 0.947 | -0.017 |
| L06 | 0.956 | 0.936 | -0.020 |
| L07 | 0.982 | 0.982 | +0.000 |
| L08 | 0.971 | 0.970 | -0.001 |
| L09 | 0.959 | 0.976 | +0.017 |
| L20 | 0.959 | 0.975 | +0.016 |
| L21 | 0.966 | 0.979 | +0.014 |
| L34 | 0.976 | 0.960 | -0.016 |
| L35 | 0.980 | 0.967 | -0.013 |
| avg | 0.966 | 0.960 | -0.006 |

This diagnostic motivated reserving `[0, 1, 2, 35]` as full-attention layers and
training only layers `3..34`.
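
The same diagnostic can be automated: given per-layer raw-QK and learned mass@K
numbers (the JSON layout below is an assumed illustration, not the actual
schema of the `compare_retrieval` output), reserve any layer whose learned mass
lags raw-QK by more than a threshold.

```python
import json

def pick_reserved_layers(compare_path, threshold=0.03):
    """Assumed record format:
    {"layers": [{"layer": 0, "raw_mass": 0.922, "learned_mass": 0.780}, ...]}"""
    with open(compare_path) as f:
        per_layer = json.load(f)["layers"]
    weak = [rec["layer"] for rec in per_layer
            if rec["raw_mass"] - rec["learned_mass"] > threshold]
    return sorted(weak)  # at threshold 0.03 the table above would flag L00-L02
```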

Final all32 reserved-edge training trajectory:

| Step | Recall@K eval | PPL gap | Read |
|---:|---:|---:|---|
| 250 | 0.812 | +2.283% | already better than all36 best training eval |
| 500 | 0.823 | +1.753% | converged to final quality band |
| 750 | 0.825 | +1.943% | small eval fluctuation |
| 1000 | 0.825 | +1.746% | final checkpoint; essentially tied with step 500 |

The all32 checkpoint is the current broad-substitution result. It is not
full-attention parity at K=128 in training eval, but it reduces the all36
quality cost while still substituting 32 of 36 layers. Post-hoc
`compare_retrieval` on step 1000 shows learned retrieval matches raw-QK mass on
the substituted layers: at K=128, learned mass is 0.971 vs raw-QK 0.969; at
K=256, learned mass is 0.993 vs raw-QK 0.994.

Exact K-sweep on the final all32 checkpoint, 2-batch clean block-causal slice
(`PPL_full = 20.5349`):

| K | mass@K | Recall@K | sparse PPL | PPL gap |
|---:|---:|---:|---:|---:|
| 16 | 0.546 | 0.518 | 24.86 | +21.064% |
| 32 | 0.627 | 0.572 | 21.85 | +6.422% |
| 64 | 0.722 | 0.652 | 20.94 | +1.974% |
| 128 | 0.807 | 0.746 | 20.66 | +0.590% |
| 256 | 0.902 | 0.876 | 20.52 | -0.062% |

K=512 is intentionally omitted from this table. The current script produced a
valid sparse-attention PPL line for K=512 but zero mass/recall, which is an
edge-case bug in the metric path when K exceeds the number of valid causal keys
for most same-segment queries. It should be rerun after fixing the metric
handling; the publishable sweep for now is K <= 256.
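
The K=512 issue is metric bookkeeping rather than modeling: when K exceeds the
number of valid causal, same-segment keys for a query, the retrieved-set size
has to be clamped before recall/mass are computed. A hedged sketch of that fix
(not the repository's actual patch):

```python
import torch

def safe_topk(scores_row, k):
    """scores_row: [T_k] search scores with invalid (non-causal or other-segment)
    positions already set to -inf. Clamp K to the number of valid keys so that
    recall@K / mass@K are computed over a real retrieved set."""
    n_valid = int(torch.isfinite(scores_row).sum())
    k_eff = max(1, min(k, n_valid))
    return scores_row.topk(k_eff).indices
```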

Coverage now looks like a real deployment knob:

| Configuration | Layers substituted | Coverage | PPL gap | Read |
|---|---:|---:|---:|---|
| Clean six-layer pilot | 6/36 | 17% | +0.07% at K=128 | quality-preserving pilot |
| all32 reserved-edge | 32/36 | 89% | +1.746% train eval; +0.590% exact sweep | near-parity broad substitution |
| all36 | 36/36 | 100% | +3.23% best observed | full substitution costs quality |

This is not yet enough to claim an optimal coverage ratio, but it suggests the
best deployment point is intermediate rather than "sparsify everything." A
12/18/20-layer coverage sweep is the next clean experiment.

## Positioning against related methods

The paper frames this method as closest in asymptotic shape to Reformer and
closest in practical baseline behavior to Quest.

| Method | Selection mechanism | Query-aware | Trained | Asymptotic | Exact softmax |
|---|---|---|---|---|---|
| Full attention | all keys | n/a | n/a | O(N²) | yes |
| Reformer | LSH hashing | yes | no | O(N log N) | over bucket |
| Performer | random features | n/a | no | O(N) | no |
| BigBird | window + random + global | mostly no | no | O(N) | over pattern |
| Longformer | sliding window + global | mostly no | no | O(N) | over pattern |
| NSA-style methods | block compression/selection | partial | partial | O(N²) proxy | yes |
| Quest | min/max page heuristic | yes | no | O(N) | over pages |
| This work | trained low-dim retrieval | yes | yes | O(N log N) | over retrieved set |

This is a design-positioning table, not a claim of completed production
superiority. The clean result proves the approach for the six-layer pilot, and
the all32 reserved-edge run shows broad substitution can get close to parity
when weak edge layers remain full attention. All36 full substitution is still
not parity.

This method targets a different deployment scenario than native
sliding-window/state-space/hybrid architectures such as Mistral-style sliding
window, Mamba, or Qwen3.6 Gated DeltaNet hybrids. Those models are trained from
scratch with their sparse or hybrid mechanism in place. This work is post-hoc:
train a base model with full attention for maximum expressivity, then add
lightweight retrieval projections afterward to make inference sub-linear without
changing base weights.

## Checkpoints

Important checkpoint paths in this HF repo:

- `checkpoints_block_d128/search_step_1000.pt`: clean six-layer d128 parity checkpoint.
- `checkpoints_all36_d128_block/protected/search_step_500_keep.pt`: best observed all-36 checkpoint so far.
- `checkpoints_all36_d128_block/search_step_800.pt`: latest all-36 checkpoint before stopping for analysis.
- `checkpoints_all32_d128_block_reserve_0_1_2_35/search_step_1000.pt`: final all32 reserved-edge checkpoint.
- `checkpoints_all32_d128_block_reserve_0_1_2_35/search_step_1000.compare_retrieval.json`: all32 per-layer retrieval comparison.
- `checkpoints_all32_d128_block_reserve_0_1_2_35/search_step_1000.k_sweep_exact.json`: all32 exact K-sweep.

These checkpoints contain the trained search projection module and optimizer
state. They do not contain or modify the base Qwen model weights.
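
To inspect one of these checkpoints locally, the snippet below only downloads
the file and lists what was saved; the exact dictionary layout is defined by
the training script on GitHub, so no specific keys are assumed here.

```python
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="datasysdev/ann-sparseattention",
    filename="checkpoints_block_d128/search_step_1000.pt",
)
# On newer PyTorch, pass weights_only=False if the file stores non-tensor objects.
ckpt = torch.load(ckpt_path, map_location="cpu")

# List the top-level entries (search-projection weights, optimizer state, ...).
if isinstance(ckpt, dict):
    for name, value in ckpt.items():
        print(name, getattr(value, "shape", type(value)))
```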

## Limitations

- No production wall-clock speedup has been measured.
- No GPU-resident ANN or fused sparse attention kernel yet.
- No autoregressive KV-cache integration yet.
- Dynamic indexing is currently validated only through a retrieval-mass proxy.
- Main clean results are single-model and mostly single-seed.
- All-36 broad substitution is not full-attention parity.
- The all32 result is near parity on a small slice but still needs larger eval
  slices, task benchmarks, and a coverage Pareto sweep.

Use the GitHub repository for runnable code, scripts, and the LaTeX paper draft.
|