Atlas-MAG with Omega Rule
A 43M-parameter implementation of the Atlas paper's Memory-As-Gate (MAG) architecture with polynomial memory and test-time learning (TTL).
Paper: Atlas: Learning to Optimally Memorize the Context at Test Time (Behrouz et al., Google Research)
Code: toddwbucy/Atlas-MAG_OmegaRule
What This Model Demonstrates
This checkpoint exists to demonstrate a concrete infrastructure problem: test-time learning models cannot be served by existing deployment stacks.
Atlas-MAG uses gradient descent during the forward pass to update its memory. The implementation gates that update behind PyTorch's `if self.training:` flag, and every serving framework calls the equivalent of `model.eval()` before serving. The model still loads and generates, but its memory architecture is silenced.
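A minimal sketch of the failure mode, using a toy memory module (hypothetical names and objective, not the repo's actual classes): the in-forward update only fires while the training flag is set.

import torch
import torch.nn as nn

class ToyTTLMemory(nn.Module):
    # Hypothetical stand-in for a test-time-learning memory, for illustration only.
    def __init__(self, dim=512, lr=0.01):
        super().__init__()
        self.M = nn.Parameter(torch.zeros(dim, dim))
        self.lr = lr

    def forward(self, x):
        if self.training:  # the gate that model.eval() turns off
            with torch.enable_grad():
                loss = (x @ self.M - x).pow(2).mean()        # toy reconstruction objective
                grad = torch.autograd.grad(loss, self.M)[0]
            self.M.data -= self.lr * grad                     # gradient step inside the forward pass
        return x @ self.M

mem, x = ToyTTLMemory(), torch.randn(4, 512)
mem.train(); mem(x)   # memory updates during the forward pass
mem.eval();  mem(x)   # same weights, same input: no update happens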
Two scripts in the GitHub repo let you see this firsthand.
Quick Start
git clone https://github.com/toddwbucy/Atlas-MAG_OmegaRule.git
cd Atlas-MAG_OmegaRule
pip install torch huggingface_hub tokenizers
# Demo: same model, same weights, different outputs depending on training flag
python scripts/demo_ttl_inference.py
# Benchmark: NIAH memory probe — TTL ON vs TTL OFF side by side
python scripts/benchmark_niah.py
Both scripts auto-download this checkpoint. No manual download needed.
Model Details
| Property | Value |
|---|---|
| Architecture | Atlas-MAG (Memory-As-Gate) |
| Parameters | 43M |
| Dimensions | dim=512, 6 layers, 8 heads |
| Memory | Polynomial degree-2, rank-512 |
| Attention | Sliding window, size=512 |
| TTL | Muon optimizer (NS-5), theta=0.9, alpha=0.999, eta=0.01 (see the NS-5 sketch below the table) |
| Vocab | 49,152 (SmolLM tokenizer) |
| Training Steps | 8,800 |
| Training Hardware | 2x NVIDIA A6000 48GB |
| Training Data | SmolLM-Corpus (cosmopedia 40%, fineweb-edu 50%, python-edu 10%) |
| NIAH Accuracy | 85.9% |
| Checkpoint Size | 473MB |
| Format | PyTorch (.pt) |
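The TTL row above refers to Muon's 5-step Newton-Schulz orthogonalization (NS-5). Below is a minimal sketch of that step using the commonly published quintic coefficients; the checkpoint's actual TTL update (and how it uses theta, alpha, and eta) lives in the repo and may differ in detail.

import torch

def newton_schulz_5(G, steps=5, eps=1e-7):
    # Quintic Newton-Schulz iteration that approximately orthogonalizes a
    # gradient/momentum matrix: the "NS-5" step inside the Muon optimizer.
    a, b, c = 3.4445, -4.7750, 2.0315   # coefficients from the public Muon reference
    X = G / (G.norm() + eps)
    tall = X.size(-2) > X.size(-1)
    if tall:
        X = X.mT
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    return X.mT if tall else X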
Architecture
Input -> Embedding -> [MAGBlock x 6] -> RMSNorm -> LM Head -> Output
MAGBlock:
x --+--> [Sliding Window Attention] --> attn_out
| |
+--> [Deep Polynomial Memory] --> mem_out
|
output = x + attn_out * sigmoid(mem_out)
The polynomial feature map increases memory capacity from O(d_k) to O(d_k^2) per layer; with dim=512 split across 8 heads, d_k = 64, so roughly 64x more key-value associations.
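A minimal sketch of the block above (illustrative module names, not the repo's implementation): the attention output is gated elementwise by a sigmoid of the memory read, and a degree-2 feature map is what lifts key capacity from d_k to d_k^2.

import torch
import torch.nn as nn

def poly2(k):
    # Degree-2 feature map: a d_k-dim key becomes d_k^2-dim,
    # which is where the ~64x capacity increase per layer comes from.
    return torch.einsum("...i,...j->...ij", k, k).flatten(-2)

class MAGBlockSketch(nn.Module):
    # Illustrative Memory-As-Gate block; the real attention and memory modules live in the repo.
    def __init__(self, dim=512, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.memory = nn.Linear(dim, dim)   # placeholder for the deep polynomial memory

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x, need_weights=False)   # sliding-window mask omitted
        mem_out = self.memory(x)                                # memory read
        return x + attn_out * torch.sigmoid(mem_out)            # memory gates attention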
Loading
import torch
from huggingface_hub import hf_hub_download
# Download checkpoint
ckpt_path = hf_hub_download("r3d91ll/Atlas-MAG_OmegaRule", "checkpoint_step008800.pt")
checkpoint = torch.load(ckpt_path, map_location="cuda:0", weights_only=False)
# The checkpoint contains:
# - "model_state_dict": model weights
# - "config": full training configuration dict
print(checkpoint["config"])
For full model loading, see the GitHub repository, which includes the model class and demo scripts.
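As a quick sanity check, the keys listed above are enough to confirm the parameter count from the state dict alone:

state = checkpoint["model_state_dict"]
n_params = sum(t.numel() for t in state.values())
print(f"{len(state)} tensors, {n_params / 1e6:.1f}M parameters")   # expect roughly 43M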
Files
| File | Size | Description |
|---|---|---|
| `checkpoint_step008800.pt` | 473MB | Model weights + config + optimizer state |
| `tokenizer_smollm.json` | 2.2MB | BPE tokenizer (SmolLM) |
Citation
@article{behrouz2025atlas,
title={Atlas: Learning to Optimally Memorize the Context at Test Time},
author={Behrouz, Ali and Li, Yingcong and Kacham, Praneeth and Daliri, Poria and Deng, Zhihao and Zhong, Peilin and Razaviyayn, Meisam and Mirrokni, Vahab},
journal={arXiv preprint arXiv:2505.23735},
year={2025}
}
License
MIT