[Feature request] Eliminate pre-attention RMSNorm in MLA models via scale invariance + weight folding

#247
Opened by graefics

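Because RMSNorm is scale-invariant, when an RMSNorm layer is followed by a bias-free linear projection and then by another RMSNorm, the normalization step of the first RMSNorm (dividing by the root mean square) can be eliminated entirely. This is a mathematically lossless simplification, apart from the small ε added inside the square root for numerical stability.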
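A short derivation of why the first normalization drops out. The notation is ours, not the paper's: g is the first norm's weight, h the second norm's weight, W a bias-free linear projection, d the hidden size, and ε is set to 0 for the exact identity.

```latex
\mathrm{RMSNorm}(x; g) = g \odot \frac{x}{\mathrm{rms}(x)}, \qquad
\mathrm{rms}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}

\text{With } \epsilon = 0 \text{ and } c > 0:\quad
\mathrm{rms}(c\,x) = c\,\mathrm{rms}(x)
\;\Rightarrow\; \mathrm{RMSNorm}(c\,x; g) = \mathrm{RMSNorm}(x; g)

\mathrm{RMSNorm}\big(W\,\mathrm{RMSNorm}(x; g);\, h\big)
  = \mathrm{RMSNorm}\Big(\tfrac{1}{\mathrm{rms}(x)}\, W (g \odot x);\, h\Big)
  = \mathrm{RMSNorm}\big(W (g \odot x);\, h\big)
```

The positive scalar 1/rms(x) is absorbed by the second norm, but the first norm's weight g is still present in the result.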

For MLA (Multi-head Latent Attention) models that normalize their latent vectors (here kv_a_layernorm and q_a_layernorm), this means the pre-attention RMSNorm can be removed without changing model outputs; see the FlashNorm paper.
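A minimal NumPy sketch of the MLA case, assuming ε = 0 and bias-free projections. W_q_a and W_kv_a are hypothetical random stand-ins for q_a_proj and for the part of the KV down-projection whose output goes through kv_a_layernorm; the toy only models these normalized latent paths, not the rest of attention. With a real checkpoint's positive ε the match would be approximate rather than exact.

```python
import numpy as np

def rms_norm(x, weight, eps=0.0):
    # RMSNorm over the last axis; eps=0 makes the scale invariance exact
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return weight * x / rms

rng = np.random.default_rng(0)
d_model, q_rank, kv_rank = 64, 32, 16               # hypothetical toy sizes

x = rng.standard_normal((4, d_model))               # hidden states for 4 tokens
g_pre = rng.uniform(0.5, 1.5, d_model)              # pre-attention norm weight
W_q_a = rng.standard_normal((q_rank, d_model))      # stand-in for q_a_proj
W_kv_a = rng.standard_normal((kv_rank, d_model))    # stand-in for the normalized KV latent path
g_q = rng.uniform(0.5, 1.5, q_rank)                 # stand-in for q_a_layernorm weight
g_kv = rng.uniform(0.5, 1.5, kv_rank)               # stand-in for kv_a_layernorm weight

def latents(h):
    # Project the input and apply the latent norms that follow each projection
    q_lat = rms_norm(h @ W_q_a.T, g_q)
    kv_lat = rms_norm(h @ W_kv_a.T, g_kv)
    return q_lat, kv_lat

# Original: full pre-attention RMSNorm, then projections and latent norms
q_ref, kv_ref = latents(rms_norm(x, g_pre))
# Simplified: skip the division by rms(x), keep only the elementwise weight
q_new, kv_new = latents(g_pre * x)

print(np.allclose(q_ref, q_new), np.allclose(kv_ref, kv_new))  # True True
```

The remaining multiplication by g_pre is exactly what the weight-folding step removes.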


However, the pre-attention norm's learned weights are still needed. They can be eliminated as well by folding them into the QKV projection weights with the FlashNorm weight-folding trick, again with no loss in model accuracy.
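A minimal sketch of the folding step in NumPy, assuming ε = 0, a bias-free projection, and the PyTorch nn.Linear weight layout (out_features, in_features). Since W (g ⊙ x) = (W diag(g)) x, folding amounts to scaling column j of W by g[j]; all sizes and arrays here are hypothetical random stand-ins.

```python
import numpy as np

def rms_norm(x, weight, eps=0.0):
    # RMSNorm over the last axis; eps=0 makes the scale invariance exact
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return weight * x / rms

rng = np.random.default_rng(1)
d_model, d_out = 64, 48                       # hypothetical toy sizes

x = rng.standard_normal((4, d_model))
g_pre = rng.uniform(0.5, 1.5, d_model)        # pre-attention norm weight to fold away
W = rng.standard_normal((d_out, d_model))     # projection weight, layout (out, in)
g_post = rng.uniform(0.5, 1.5, d_out)         # weight of the norm after the projection

# Fold: scale each input column j of W by g_pre[j], i.e. W_folded = W @ diag(g_pre)
W_folded = W * g_pre[np.newaxis, :]

ref = rms_norm(rms_norm(x, g_pre) @ W.T, g_post)  # norm -> linear -> norm
folded = rms_norm(x @ W_folded.T, g_post)         # folded linear -> norm, no pre-norm

print(np.allclose(ref, folded))  # True
```

Because the fold is done once offline on the checkpoint, the pre-attention norm (weights and computation) disappears from the inference graph entirely.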


For reference, we have applied this weight-folding trick to a few LLMs (Llama, Qwen, SmolLM); the resulting models are here:
https://huggingface.co/models?other=weightless-rmsnorm
