---
license: mit
tags:
- mechanistic-interpretability
- grokking
- modular-arithmetic
- transformer
- TransformerLens
- pytorch
- toy-model
language:
- en
library_name: transformers
pipeline_tag: text-classification
---

# Modular Multiplication Transformer

A 1-layer, 4-head transformer trained on **(a × b) mod 113** that exhibits **grokking**: delayed generalization long after the training set has been memorized. This checkpoint includes the full training history (400 checkpoints across 40,000 epochs).

## Model Architecture

| Parameter | Value |
|-----------|-------|
| Layers | 1 |
| Attention Heads | 4 |
| d_model | 128 |
| d_head | 32 |
| d_mlp | 512 |
| Activation | ReLU |
| Layer Norm | None |
| Vocabulary | 114 (tokens 0-112 plus the "=" separator) |
| Output Classes | 113 |
| Context Length | 3 tokens: [a, b, =] |
| Trainable Parameters | ~230,000 |

Built with [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) (`HookedTransformer`), with layer normalization disabled and biases frozen.

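The table above maps onto a `HookedTransformerConfig` roughly as follows. This is only an illustrative reconstruction using TransformerLens field names; the authoritative config object is stored inside the checkpoint itself.

```python
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Illustrative config matching the architecture table; the real config
# ships in the checkpoint, so treat field values here as a sketch.
cfg = HookedTransformerConfig(
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,  # no LayerNorm
    d_vocab=114,              # tokens 0-112 plus the "=" separator
    d_vocab_out=113,          # outputs are residues mod 113
    n_ctx=3,                  # [a, b, =]
    seed=999,
)
model = HookedTransformer(cfg)
```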
## Training Details

| Parameter | Value |
|-----------|-------|
| Optimizer | AdamW |
| Learning Rate | 1e-3 |
| Weight Decay | 1.0 |
| Betas | (0.9, 0.98) |
| Epochs | 40,000 |
| Training Fraction | 30% (3,830 / 12,769 samples) |
| Batch Size | Full-batch |
| Data Seed | 598 |
| Model Seed | 999 |

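One plausible way the 30% split could be built: enumerate all 113 × 113 = 12,769 (a, b) pairs, shuffle with a fixed seed, and take the first 30%. The exact shuffling scheme is not specified here (the real split may use torch's RNG with seed 598, and the actual indices are stored in the checkpoint), so this is only a sketch:

```python
import random

P = 113  # modulus; also the number of output classes

# Enumerate every (a, b) pair exactly once: 113 * 113 = 12,769 samples.
pairs = [(a, b) for a in range(P) for b in range(P)]

# Shuffle with a fixed seed and carve off 30% for training.
# NOTE: illustrative only -- the checkpoint stores the real
# train/test indices under "train_indices" / "test_indices".
rng = random.Random(598)
indices = list(range(len(pairs)))
rng.shuffle(indices)

n_train = int(0.3 * len(pairs))   # 3,830
train_indices = indices[:n_train]
test_indices = indices[n_train:]  # 8,939

print(len(pairs), len(train_indices), len(test_indices))
```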
## Checkpoint Contents

```python
import torch

# weights_only=False is required on PyTorch >= 2.6, since the checkpoint
# pickles a HookedTransformerConfig alongside the tensors.
checkpoint = torch.load("mod_mult_grokking.pth", weights_only=False)

checkpoint["model"]              # Final model state_dict
checkpoint["config"]             # HookedTransformerConfig
checkpoint["checkpoints"]        # List of 400 state_dicts (every 100 epochs)
checkpoint["checkpoint_epochs"]  # [0, 100, 200, ..., 39900]
checkpoint["train_losses"]       # 40,000 training loss values
checkpoint["test_losses"]        # 40,000 test loss values
checkpoint["train_accs"]         # 40,000 training accuracy values
checkpoint["test_accs"]          # 40,000 test accuracy values
checkpoint["train_indices"]      # Indices of the 3,830 training samples
checkpoint["test_indices"]       # Indices of the 8,939 test samples
```

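The logged histories make it easy to locate the grokking transition, e.g. the first epoch at which test accuracy crosses a threshold. A minimal sketch, shown on a synthetic curve; with the real checkpoint you would pass `checkpoint["test_accs"]`:

```python
def grokking_epoch(test_accs, threshold=0.99):
    """Return the first epoch index whose test accuracy reaches
    `threshold`, or None if it never does."""
    for epoch, acc in enumerate(test_accs):
        if acc >= threshold:
            return epoch
    return None

# Synthetic stand-in for a grokking-style curve: test accuracy stays
# low while the model memorizes, then jumps late in training.
accs = [0.01] * 100 + [0.5] * 50 + [0.995] * 50
print(grokking_epoch(accs))  # -> 150
```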
## Usage

```python
import torch
from transformer_lens import HookedTransformer

# weights_only=False is required on PyTorch >= 2.6 (the checkpoint
# pickles a config object, not just tensors).
checkpoint = torch.load("mod_mult_grokking.pth", map_location="cpu",
                        weights_only=False)
model = HookedTransformer(checkpoint["config"])
model.load_state_dict(checkpoint["model"])
model.eval()

# Compute (7 * 16) mod 113 = 112
a, b = 7, 16
separator = 113  # "=" token
input_tokens = torch.tensor([[a, b, separator]])
logits = model(input_tokens)
prediction = logits[0, -1].argmax().item()
print(f"{a} * {b} mod 113 = {prediction}")  # 112
```

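To sanity-check model predictions, compare them against the exact answer. A tiny helper (hypothetical, not part of the checkpoint) that computes the ground-truth label for any pair:

```python
P = 113  # task modulus

def mod_mult(a: int, b: int) -> int:
    """Ground-truth label for the task: (a * b) mod 113."""
    return (a * b) % P

# Spot-check the example from the Usage section and a couple of edge cases.
print(mod_mult(7, 16))     # 112
print(mod_mult(0, 55))     # 0
print(mod_mult(112, 112))  # 1, since 112 is congruent to -1 (mod 113)
```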
## References

- Nanda et al. (2023). "Progress measures for grokking via mechanistic interpretability." ICLR 2023.
- Power et al. (2022). "Grokking: Generalization beyond overfitting on small algorithmic datasets."
- [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens)