Atlas-MAG with Omega Rule

A 43M parameter implementation of the Atlas paper's Memory-As-Gate (MAG) architecture with polynomial memory and test-time learning (TTL).

Paper: Atlas: Learning to Optimally Memorize the Context at Test Time (Behrouz et al., Google Research)

Code: toddwbucy/Atlas-MAG_OmegaRule

What This Model Demonstrates

This checkpoint exists to demonstrate a concrete infrastructure problem: test-time learning models cannot be served by existing deployment stacks.

Atlas-MAG runs gradient descent during the forward pass to update its memory. PyTorch convention gates that update behind the self.training flag, and every serving framework calls model.eval() (or the equivalent inference mode) before serving. The result: the model's memory architecture is silently disabled.
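A minimal sketch of the failure mode, under stated assumptions: the TTLMemory class below is a hypothetical toy layer, not the actual Atlas-MAG code. It updates an internal associative memory during the forward pass, gated behind self.training exactly as PyTorch convention dictates, so calling eval() (as every serving stack does) freezes the memory and changes the output.

```python
import torch
import torch.nn as nn

class TTLMemory(nn.Module):
    """Toy test-time-learning layer (illustration only, not Atlas-MAG itself)."""
    def __init__(self, dim: int = 8):
        super().__init__()
        # Associative memory written during the forward pass, not by an optimizer.
        self.register_buffer("memory", torch.zeros(dim, dim))
        self.eta = 0.1  # memory learning rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:  # the gate that serving stacks switch off
            # One Hebbian-style update per call: memory accumulates x^T x.
            self.memory += self.eta * (x.T @ x)
        return x @ self.memory.T

layer = TTLMemory()
x = torch.ones(1, 8)

layer.eval()           # what every serving framework does before inference
y_frozen = layer(x)    # memory is never written -> output carries no context

layer.train()          # TTL active: memory is written during the forward pass
y_ttl = layer(x)

print(torch.equal(y_frozen, torch.zeros_like(y_frozen)))  # True: memory stayed zero
print(torch.equal(y_ttl, y_frozen))                       # False: same weights, different output
```

Same parameters in both calls; only the training flag differs. This is the behavior the two demo scripts below exercise against the real checkpoint.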

Two scripts in the GitHub repo let you see this firsthand.

Quick Start

git clone https://github.com/toddwbucy/Atlas-MAG_OmegaRule.git
cd Atlas-MAG_OmegaRule
pip install torch huggingface_hub tokenizers

# Demo: same model, same weights, different outputs depending on training flag
python scripts/demo_ttl_inference.py

# Benchmark: NIAH memory probe — TTL ON vs TTL OFF side by side
python scripts/benchmark_niah.py

Both scripts auto-download this checkpoint. No manual download needed.

Model Details

Architecture: Atlas-MAG (Memory-As-Gate)
Parameters: 43M
Dimensions: dim=512, 6 layers, 8 heads
Memory: polynomial, degree-2, rank-512
Attention: sliding window, size=512
TTL: Muon optimizer (NS-5), theta=0.9, alpha=0.999, eta=0.01
Vocab: 49,152 (SmolLM tokenizer)
Training steps: 8,800
Training hardware: 2x NVIDIA A6000 48GB
Training data: SmolLM-Corpus (cosmopedia 40%, fineweb-edu 50%, python-edu 10%)
NIAH accuracy: 85.9% (needle-in-a-haystack)
Checkpoint size: 473MB
Format: PyTorch (.pt)

Architecture

Input -> Embedding -> [MAGBlock x 6] -> RMSNorm -> LM Head -> Output

MAGBlock:
    x --+--> [Sliding Window Attention] --> attn_out
        |                                      |
        +--> [Deep Polynomial Memory]  --> mem_out
                                               |
        output = x + attn_out * sigmoid(mem_out)

The polynomial feature map increases memory capacity from O(d_k) to O(d_k^2) per layer — roughly 64x more associations.
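A sketch of both ideas under stated assumptions: the real MAGBlock uses learned projections and the Muon-based TTL update, whereas here phi is the plain degree-2 outer-product feature map and the memory is a random stand-in matrix. It shows the d_k -> d_k^2 feature expansion and the sigmoid-gated combination from the diagram above.

```python
import torch

def poly_features(k: torch.Tensor) -> torch.Tensor:
    """Degree-2 polynomial feature map: d_k dims -> d_k^2 dims (outer product)."""
    return torch.einsum("...i,...j->...ij", k, k).flatten(-2)

d_k = 64
k = torch.randn(d_k)
phi = poly_features(k)
print(phi.shape)  # torch.Size([4096]) -- 64x more feature slots than d_k

# MAG combination from the diagram: the memory output gates the attention path.
x = torch.randn(512)
attn_out = torch.randn(512)           # stand-in for sliding-window attention output
memory = torch.randn(512, d_k * d_k)  # stand-in for the learned polynomial memory
mem_out = memory @ phi
output = x + attn_out * torch.sigmoid(mem_out)
print(output.shape)  # torch.Size([512])
```

The gate means attention output only passes through where the memory readout is confident, which is how the memory path modulates the attention path rather than simply adding to it.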

Loading

import torch
from huggingface_hub import hf_hub_download

# Download checkpoint
ckpt_path = hf_hub_download("r3d91ll/Atlas-MAG_OmegaRule", "checkpoint_step008800.pt")
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)  # config dict requires weights_only=False

# The checkpoint contains:
# - "model_state_dict": model weights
# - "config": full training configuration dict
print(checkpoint["config"])

For full model loading, see the GitHub repository, which includes the model class and the demo scripts.

Files

checkpoint_step008800.pt (473MB): model weights + config + optimizer state
tokenizer_smollm.json (2.2MB): BPE tokenizer (SmolLM)

Citation

@article{behrouz2025atlas,
  title={Atlas: Learning to Optimally Memorize the Context at Test Time},
  author={Behrouz, Ali and Li, Zeman and Kacham, Praneeth and Daliri, Majid and Deng, Yuan and Zhong, Peilin and Razaviyayn, Meisam and Mirrokni, Vahab},
  journal={arXiv preprint arXiv:2505.23735},
  year={2025}
}

License

MIT
