# AttnVQ: Attention-Aware KV Cache Quantization
Training-free product vector quantization of the KV cache for long-context LLMs. AttnVQ fits small per-subspace codebooks with attention-weighted batched LBG, where centroids are weighted by each key's attention mass under GQA causal attention. It scores distortion by attention-output error (plus key cosine similarity and inner-product bias) rather than cache MSE. Calibration is light: 10-15 agent traces and ~15 s on a GPU are enough to capture the model's K/V geometry. The codebooks are data-aware but not tied to a particular corpus.
Primary target: Laguna-XS.2 (the method itself is model-agnostic). Only the 10 full-attention layers are compressed; the 30 sliding-window layers stay in fp16.
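The codebook fit described above can be sketched in NumPy as weighted Lloyd iterations, which are the core step of LBG, batched over subspaces. This is a minimal sketch under stated assumptions and not the repo's ProductVQ. Keys are split into contiguous subspaces. The per-key weights `w` stand in for attention mass. Codebooks are initialized by random sampling instead of LBG's codebook splitting. The helper names `fit_weighted_pq`, `encode` and `decode` are hypothetical.

```python
import numpy as np

def _split(x, n_sub):
    """[N, d] -> [M, N, d // M]: contiguous chunks of each vector, one per subspace."""
    n, d = x.shape
    assert d % n_sub == 0, "d must split evenly into subspaces"
    return x.reshape(n, n_sub, d // n_sub).transpose(1, 0, 2)

def _sq_dists(xs, cb):
    """Squared L2 distances per subspace. xs: [M, N, ds], cb: [M, K, ds] -> [M, N, K]."""
    return ((xs ** 2).sum(-1)[:, :, None]
            - 2.0 * np.einsum("mnd,mkd->mnk", xs, cb)
            + (cb ** 2).sum(-1)[:, None, :])

def fit_weighted_pq(x, w, n_sub, n_codes, n_iter=20, seed=0):
    """Attention-weighted Lloyd (LBG-style) iterations, batched over subspaces.

    x: [N, d] cached keys (or values); w: [N] non-negative per-vector weights,
    e.g. the attention mass each key receives (an assumption in this sketch).
    Returns codebooks of shape [n_sub, n_codes, d // n_sub].
    """
    xs = _split(x, n_sub)
    n = xs.shape[1]
    rng = np.random.default_rng(seed)
    # Random-sample initialization (LBG proper grows codebooks by splitting).
    cb = xs[:, rng.choice(n, size=n_codes, replace=False), :].copy()
    for _ in range(n_iter):
        codes = _sq_dists(xs, cb).argmin(-1)                  # [M, N] nearest centroid
        for m in range(n_sub):
            mass = np.bincount(codes[m], weights=w, minlength=n_codes)  # [K] weight per cell
            sums = np.zeros_like(cb[m])
            np.add.at(sums, codes[m], w[:, None] * xs[m])     # [K, ds] weighted sums
            keep = mass > 0                                   # empty cells keep their centroid
            cb[m, keep] = sums[keep] / mass[keep, None]       # weighted-mean update
    return cb

def encode(x, cb):
    """[N, d] -> [N, M] uint8 codes (valid while n_codes <= 256)."""
    return _sq_dists(_split(x, cb.shape[0]), cb).argmin(-1).T.astype(np.uint8)

def decode(codes, cb):
    """[N, M] codes -> [N, d] by concatenating the selected centroid of each subspace."""
    m = cb.shape[0]
    return cb[np.arange(m)[None, :], codes.astype(np.int64)].reshape(len(codes), -1)

rng = np.random.default_rng(0)
keys = rng.standard_normal((2048, 128))   # stand-in for a cache dump
weights = rng.random(2048)                # stand-in for per-key attention mass
cb = fit_weighted_pq(keys, weights, n_sub=32, n_codes=256)
codes = encode(keys, cb)                  # [2048, 32] uint8
keys_hat = decode(codes, cb)              # [2048, 128]
```

Because each centroid is a weighted mean, keys that receive more attention pull centroids toward themselves. This biases the limited codes toward the vectors that contribute most to the attention output.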
## What this repo offers
| Component | Description |
|---|---|
| generate.py | Minimal inference: VQQuantizedCache → model.generate() |
| vqkv/ | Quantizers (ProductVQ, RoPESplit, scalar/KIVI baselines), attention-aware metrics, compressed cache |
| benchmark.py | Fit codebooks + cheap metrics (key cosine, attention-output error, inner-product bias) on real cache dumps |
| turbo_benchmark.py | Faithful TurboQuant baseline (Haar rotation + Lloyd-Max + QJL) |
| longbench_eval.py | LongBench v1 proxy metrics + optional end-to-end task scoring |
| app.py | Gradio demo for live generation |
| attnvq_slides.html | Slides, including LongBench metrics |
| artifacts/ | Pre-fit codebooks and LongBench results |
Variants: productvq-* (AttnVQ), ropesplit-1b (RoPE-half split for Laguna), scalar/KIVI/sign/ternary baselines, TurboQuant MSE/Prod.
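For orientation, the TurboQuant-style baseline involves three steps. First, it rotates each vector with a random, Haar-distributed orthogonal matrix so that every coordinate follows the same distribution. Second, it quantizes each coordinate with a scalar Lloyd-Max codebook. Third, as described in the TurboQuant paper, the inner-product ("Prod") variant adds a 1-bit QJL sketch of the residual. The sketch below is a simplified NumPy illustration of those ingredients and not turbo_benchmark.py. In particular, it fits the Lloyd-Max levels empirically on the rotated data, and its number of QJL projections is chosen only for illustration. All function names are hypothetical.

```python
import numpy as np

def haar_rotation(d, rng):
    """Random d x d orthogonal matrix, Haar-distributed: QR of a Gaussian with sign fix."""
    q, r = np.linalg.qr(rng.standard_normal((d, d)))
    return q * np.sign(np.diag(r))  # rescale column j by sign(r_jj)

def lloyd_max_1d(samples, n_levels, n_iter=50):
    """Scalar Lloyd-Max quantizer fitted to empirical samples (1-D k-means)."""
    levels = np.quantile(samples, (np.arange(n_levels) + 0.5) / n_levels)
    for _ in range(n_iter):
        edges = (levels[:-1] + levels[1:]) / 2       # midpoint decision thresholds
        idx = np.searchsorted(edges, samples)
        for j in range(n_levels):
            cell = samples[idx == j]
            if cell.size:
                levels[j] = cell.mean()               # centroid condition
        levels.sort()
    return levels

def scalar_quantize(x, levels):
    """Map every element of x to its nearest reconstruction level."""
    edges = (levels[:-1] + levels[1:]) / 2
    return levels[np.searchsorted(edges, x)]

def qjl_inner_product(q, r, s):
    """Unbiased estimate of <q, r> from the 1-bit codes sign(S r) and the norm ||r||.

    Uses E[sign(<s, r>) <s, q>] = sqrt(2/pi) <q, r> / ||r|| for rows s ~ N(0, I).
    """
    signs = np.sign(s @ r)
    return np.sqrt(np.pi / 2) / s.shape[0] * np.linalg.norm(r) * (signs @ (s @ q))

rng = np.random.default_rng(0)
d = 64
x = rng.standard_normal((2000, d))                # stand-in for cached keys
R = haar_rotation(d, rng)
x_rot = x @ R.T                                    # rotate each row: R x
levels = lloyd_max_1d(x_rot.ravel(), n_levels=4)   # 2 bits per coordinate
x_hat = scalar_quantize(x_rot, levels) @ R         # quantize, then rotate back with R^T
# "Prod"-style correction: add a QJL estimate of <q, residual> to <q, x_hat>.
q = rng.standard_normal(d)
S = rng.standard_normal((4096, d))
est = x_hat[0] @ q + qjl_inner_product(q, x[0] - x_hat[0], S)
```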
## Headline results (Laguna-XS.2)
Memory @ 131K context (full-attention layers only):
| Config | KV cache (compression vs fp16) |
|---|---|
| fp16 | 5.4 GB |
| AttnVQ 2-bit (productvq-32x256-2b) | 0.73 GB (7.4×) |
| AttnVQ 1-bit (productvq-16x256-1b) | 0.40 GB (14×) |
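The table can be sanity-checked with a back-of-the-envelope calculation. This sketch makes two assumptions that are not taken from the repo's config. First, it assumes 8 KV heads of dimension 128 per full-attention layer; the 5.4 GB figure only implies that their product is 1024 values per token per layer for K or V. Second, it reads productvq-32x256-2b as 32 subspaces with 256 centroids each, i.e. one uint8 index per 4-dim subspace. Indices alone give the ideal ratio. The reported 0.73 GB and 0.40 GB are somewhat higher, so the real cache holds more than the bare indices; the README does not itemize the difference.

```python
GB = 1e9  # decimal gigabytes

N_LAYERS = 10        # full-attention layers (the only ones compressed)
SEQ_LEN = 131_072    # "131K" context
KV_HEADS = 8         # ASSUMED: only KV_HEADS * HEAD_DIM = 1024 is implied by 5.4 GB
HEAD_DIM = 128       # ASSUMED

def cache_bytes(bytes_per_vector):
    """K and V (x2) for every layer, token and KV head."""
    return N_LAYERS * SEQ_LEN * 2 * KV_HEADS * bytes_per_vector

def pq_bytes_per_vector(n_sub, n_codes):
    """One index per subspace: log2(n_codes) bits each, for power-of-two codebook sizes."""
    return n_sub * (n_codes - 1).bit_length() / 8

configs = {
    "fp16": HEAD_DIM * 2,                                 # 2 bytes per element
    "productvq-32x256-2b": pq_bytes_per_vector(32, 256),  # ASSUMED reading of the name
    "productvq-16x256-1b": pq_bytes_per_vector(16, 256),
}
fp16 = cache_bytes(configs["fp16"])
for name, bpv in configs.items():
    size = cache_bytes(bpv)
    print(f"{name}: {size / GB:.2f} GB, {fp16 / size:.1f}x, {8 * bpv / HEAD_DIM:.1f} bits/element")
```

This prints 5.37 GB for fp16, 0.67 GB (8.0x, 2.0 bits/element) for the 2-bit config and 0.34 GB (16.0x, 1.0 bits/element) for the 1-bit config.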
LongBench v1 (mean F1 over qasper, 2wikimqa, hotpotqa, repobench-p; single 15-trace codebook):
- 2-bit: ~96% of fp16 (TurboQuant ~83%, INT2 ~75%)
- 1-bit: AttnVQ and RoPESplit beat every iso-budget baseline on every task
- 0.5-bit: only VQ methods reach this regime at all (scalar and sign quantizers need at least 1 bit per element)
Full numbers: artifacts/longbench_results.json, artifacts/longbench_cheap_metrics.json.
Note: the savings above are in memory; a wall-clock speedup additionally requires a fused dequantization kernel.
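One way to see why the kernel matters: an unfused path expands the codes into dense keys and then runs the ordinary `q @ K.T`. A standard product-quantization alternative is to score a query directly against the codes through per-subspace lookup tables (asymmetric distance computation). The NumPy sketch below shows that both give the same scores. The shapes and function names are illustrative, and the README does not say how AttnVQ's fused kernel is, or will be, implemented.

```python
import numpy as np

def dequantize(codes, cb):
    """codes: [N, M] uint8, cb: [M, K, ds] -> dense keys [N, M * ds]."""
    m = cb.shape[0]
    parts = cb[np.arange(m)[None, :], codes.astype(np.int64)]  # [N, M, ds]
    return parts.reshape(codes.shape[0], -1)

def scores_unfused(q, codes, cb):
    # Materialize every key, then take dot products with the query.
    return dequantize(codes, cb) @ q

def scores_lut(q, codes, cb):
    # Per-subspace tables: lut[m, k] = <q_m, cb[m, k]>, only M * K small dot products.
    m, k, ds = cb.shape
    lut = np.einsum("md,mkd->mk", q.reshape(m, ds), cb)         # [M, K]
    # Each key's score is the sum of its M table entries; no dense keys are built.
    return lut[np.arange(m)[None, :], codes.astype(np.int64)].sum(axis=1)  # [N]
```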
## Quick start
Tested on CUDA 12.4 / NVIDIA A100.
```bash
pip install "git+https://github.com/huggingface/transformers.git" \
    accelerate datasets torch==2.9.1 torchvision tqdm
python generate.py
```
Use fitted codebooks for memory-efficient long-context generation:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vqkv.compressed_cache import VQQuantizedCache

# load model
tok = AutoTokenizer.from_pretrained("poolside/Laguna-XS.2", trust_remote_code=True, fix_mistral_regex=True)
model = AutoModelForCausalLM.from_pretrained(
    "poolside/Laguna-XS.2", torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True
).eval()

# load the pre-fit codebooks (or fit and use your own)
# weights_only=False unpickles arbitrary Python objects, so only load trusted files
CODEBOOKS_PATH = "artifacts/codebooks.pt"
codebooks = torch.load(CODEBOOKS_PATH, map_location="cuda", weights_only=False)

# build the cache: 2-bit quantizers, applied to the full-attention layers only
quantizers = codebooks["fitted"]["productvq-32x256-2b"]
layers = codebooks["meta"]["full_layers"]
cache = VQQuantizedCache(quantizers, layers)  # persists uint8 codebook indices

# generate, then decode only the newly generated tokens
ids = tok("Hello", return_tensors="pt").to(model.device)
out = model.generate(**ids, max_new_tokens=32, past_key_values=cache, use_cache=True)
print(tok.decode(out[0, ids["input_ids"].shape[1]:], skip_special_tokens=True))

# print memory footprint
print(cache.memory_footprint())
```
## Reproduce
Precomputed results are under artifacts/. To re-fit and evaluate:
```bash
python benchmark.py --stage fit
python turbo_benchmark.py --stage fit
python longbench_eval.py --stage cheap --n_eval 50     # cheap metrics
python longbench_eval.py --stage generate --n_eval 50  # slow: full generation and task metrics
```
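The three cheap metrics can be sketched for a single head. The sketch assumes causal softmax attention over cached keys `k`, values `v` and queries `q`, with reconstructions `k_hat` and `v_hat`. The README does not give the exact definitions used by benchmark.py, such as normalization, averaging over heads and layers, or how GQA query groups are handled. The relative-error form and the mean signed inner-product bias below are therefore assumptions, and the function names are hypothetical.

```python
import numpy as np

def causal_attention(q, k, v):
    """Single-head causal softmax attention. q, k, v: [T, d] -> (output [T, d], probs [T, T])."""
    t, d = q.shape
    logits = q @ k.T / np.sqrt(d)
    logits = np.where(np.tril(np.ones((t, t), dtype=bool)), logits, -np.inf)  # mask future keys
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)
    return p @ v, p

def key_cosine(k, k_hat):
    """Mean cosine similarity between original and reconstructed keys."""
    num = (k * k_hat).sum(1)
    den = np.linalg.norm(k, axis=1) * np.linalg.norm(k_hat, axis=1) + 1e-12
    return float((num / den).mean())

def attn_output_error(q, k, v, k_hat, v_hat):
    """Relative L2 error of the attention output, the quantity the next layer consumes."""
    out, _ = causal_attention(q, k, v)
    out_hat, _ = causal_attention(q, k_hat, v_hat)
    return float(np.linalg.norm(out_hat - out) / np.linalg.norm(out))

def ip_bias(q, k, k_hat):
    """Mean signed error of query-key inner products over causal pairs (0 means unbiased)."""
    t = q.shape[0]
    diff = q @ k_hat.T - q @ k.T
    return float(diff[np.tril(np.ones((t, t), dtype=bool))].mean())
```

The column sums `p.sum(0)` of the attention probabilities give the attention mass each key receives. This is one plausible source for the per-key weights in an attention-weighted codebook fit. With GQA, that mass would presumably be accumulated over all query heads sharing a KV head.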