Upload Atom2.7m model
Browse files- .gitattributes +1 -0
- README.md +162 -0
- benchmark_fusion_arithmark.py +291 -0
- bg.png +3 -0
- config.json +29 -0
- config.py +233 -0
- configuration_gpt.py +35 -0
- lm_eval_fusion +9 -0
- lm_eval_fusion.py +299 -0
- model.py +335 -0
- model.safetensors +3 -0
- pretraining_curriculum.json +46 -0
- requirements.txt +6 -0
- special_tokens_map.json +9 -0
- tokenizer.json +0 -0
- tokenizer_config.json +11 -0
- tokenizer_utils.py +328 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
bg.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
pipeline_tag: text-generation
|
| 6 |
+
tags:
|
| 7 |
+
- causal-lm
|
| 8 |
+
- gpt
|
| 9 |
+
- small-language-model
|
| 10 |
+
- arithmetic
|
| 11 |
+
- custom-tokenizer
|
| 12 |
+
- custom-code
|
| 13 |
+
- safetensors
|
| 14 |
+
- lm-evaluation-harness
|
| 15 |
+
datasets:
|
| 16 |
+
- openbmb/Ultra-FineWeb
|
| 17 |
+
- HuggingFaceFW/fineweb-edu
|
| 18 |
+
- HuggingFaceTB/finemath
|
| 19 |
+
- HuggingFaceTB/smollm-corpus
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+

|
| 23 |
+
# Atom2.7m
|
| 24 |
+
|
| 25 |
+
Atom2.7m is a small decoder-only causal language model trained with a general byte-level BPE tokenizer plus arithmetic-specific digit features. The model has 2,738,880 parameters and uses custom code for both the model and the tokenizer path.
|
| 26 |
+
|
| 27 |
+
## Model Details
|
| 28 |
+
|
| 29 |
+
- Architecture: decoder-only GPT
|
| 30 |
+
- Parameters: 2,738,880
|
| 31 |
+
- Layers: 5
|
| 32 |
+
- Hidden size: 192
|
| 33 |
+
- Attention heads: 4
|
| 34 |
+
- KV heads: 2
|
| 35 |
+
- Context length: 512
|
| 36 |
+
- Vocabulary size: 4,096
|
| 37 |
+
- Token embeddings: tied input/output embeddings
|
| 38 |
+
- Arithmetic feature embeddings:
|
| 39 |
+
- `place_vocab_size`: 66
|
| 40 |
+
- `role_vocab_size`: 12
|
| 41 |
+
|
| 42 |
+
## Tokenizer
|
| 43 |
+
|
| 44 |
+
This model should not be evaluated or used with a plain Hugging Face tokenizer path alone. It uses a custom fusion tokenizer implemented in `tokenizer_utils.py`.
|
| 45 |
+
|
| 46 |
+
The tokenizer keeps byte-level BPE for ordinary text, but treats arithmetic-sensitive spans specially:
|
| 47 |
+
|
| 48 |
+
- digits `0`-`9` are atomic and never BPE-merged
|
| 49 |
+
- digit spans are emitted least-significant-digit first
|
| 50 |
+
- `+ - * / = ( )` are isolated atomic tokens
|
| 51 |
+
- whitespace is isolated from text
|
| 52 |
+
- `place_ids` are assigned to every digit run
|
| 53 |
+
- `role_ids` are assigned only for strict integer equation spans
|
| 54 |
+
|
| 55 |
+
The model expects aligned `input_ids`, `place_ids`, and `role_ids`.
|
| 56 |
+
|
| 57 |
+
## Usage
|
| 58 |
+
|
| 59 |
+
```python
|
| 60 |
+
from pathlib import Path
|
| 61 |
+
|
| 62 |
+
import torch
|
| 63 |
+
from transformers import AutoModelForCausalLM
|
| 64 |
+
|
| 65 |
+
from tokenizer_utils import load_tokenizer
|
| 66 |
+
|
| 67 |
+
model_dir = Path(".")
|
| 68 |
+
|
| 69 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 70 |
+
model_dir,
|
| 71 |
+
trust_remote_code=True,
|
| 72 |
+
).eval()
|
| 73 |
+
tokenizer = load_tokenizer(model_dir)
|
| 74 |
+
|
| 75 |
+
text = "12 + 34 ="
|
| 76 |
+
encoding = tokenizer.encode(text)
|
| 77 |
+
|
| 78 |
+
input_ids = torch.tensor([encoding.input_ids])
|
| 79 |
+
place_ids = torch.tensor([encoding.place_ids])
|
| 80 |
+
role_ids = torch.tensor([encoding.role_ids])
|
| 81 |
+
|
| 82 |
+
with torch.no_grad():
|
| 83 |
+
outputs = model(
|
| 84 |
+
input_ids=input_ids,
|
| 85 |
+
place_ids=place_ids,
|
| 86 |
+
role_ids=role_ids,
|
| 87 |
+
)
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
For correct results, do not rely on `pipeline("text-generation")` unless it is wrapped to provide `place_ids` and `role_ids`.
|
| 91 |
+
|
| 92 |
+
## Evaluation
|
| 93 |
+
|
| 94 |
+
### ArithMark 2.0
|
| 95 |
+
|
| 96 |
+
Use the included fusion-aware benchmark script:
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
python benchmark_fusion_arithmark.py \
|
| 100 |
+
--checkpoint . \
|
| 101 |
+
--tokenizer-dir . \
|
| 102 |
+
--data-path arithmark_2.0.jsonl \
|
| 103 |
+
--batch-size 64 \
|
| 104 |
+
--device cuda \
|
| 105 |
+
--output benchmark_results/fusion_arithmark_2.0_results.json
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
### lm-evaluation-harness
|
| 109 |
+
|
| 110 |
+
Use the included launcher so the `atom2.7m` model wrapper is registered:
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
python lm_eval_fusion run \
|
| 114 |
+
--model atom2.7m \
|
| 115 |
+
--model_args pretrained=.,tokenizer_dir=. \
|
| 116 |
+
--tasks hellaswag,arc_easy,arc_challenge,piqa \
|
| 117 |
+
--device cuda:0 \
|
| 118 |
+
--batch_size auto \
|
| 119 |
+
--output_path benchmark_results/lm_eval
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
The wrapper uses `tokenizer_utils.load_tokenizer()` and forwards `place_ids` and `role_ids` to the model.
|
| 123 |
+
|
| 124 |
+
## Results
|
| 125 |
+
|
| 126 |
+
| Benchmark | Metric | Value |
|
| 127 |
+
| --- | --- | ---: |
|
| 128 |
+
| ArithMark 2.0 | acc | 0.6380 |
|
| 129 |
+
| arc_challenge | acc_norm | 0.2261 |
|
| 130 |
+
| arc_easy | acc_norm | 0.3270 |
|
| 131 |
+
| hellaswag | acc_norm | 0.2733 |
|
| 132 |
+
| piqa | acc_norm | 0.5305 |
|
| 133 |
+
|
| 134 |
+
## Training Data
|
| 135 |
+
|
| 136 |
+
The pretraining mixture targeted about 3.5B tokens:
|
| 137 |
+
|
| 138 |
+
- Ultra-FineWeb: 900M
|
| 139 |
+
- FineWeb-Edu: 900M
|
| 140 |
+
- FineMath: 450M
|
| 141 |
+
- Cosmopedia-v2: 337.5M
|
| 142 |
+
- UltraData-Math-L2-preview: 337.5M
|
| 143 |
+
- Ultra-FineWeb-L3-en-QA-Synthetic: 225M
|
| 144 |
+
- Synthetic-Arithmetic: 350M
|
| 145 |
+
|
| 146 |
+
Synthetic-Arithmetic is AtomCalc-style canonical integer equation data. The training curriculum is included as `pretraining_curriculum.json`.
|
| 147 |
+
|
| 148 |
+
## Limitations
|
| 149 |
+
|
| 150 |
+
- This is a very small model and should be treated as an experimental research artifact.
|
| 151 |
+
- The custom tokenizer makes plain `AutoTokenizer` or default `lm_eval --model hf` unsuitable for final reported numbers.
|
| 152 |
+
- Numeric text is represented least-significant-digit first internally.
|
| 153 |
+
- Role annotations intentionally target strict integer equations, not broad math prose, decimals, rationals, or QA formats.
|
| 154 |
+
|
| 155 |
+
## Files
|
| 156 |
+
|
| 157 |
+
- `model.safetensors`: model weights
|
| 158 |
+
- `config.json`, `config.py`, `configuration_gpt.py`, `model.py`: custom model code
|
| 159 |
+
- `tokenizer.json`, `tokenizer_utils.py`: tokenizer files and fusion wrapper
|
| 160 |
+
- `benchmark_fusion_arithmark.py`: ArithMark evaluation
|
| 161 |
+
- `lm_eval_fusion.py`, `lm_eval_fusion`: lm-eval custom model wrapper
|
| 162 |
+
- `pretraining_curriculum.json`: training curriculum
|
benchmark_fusion_arithmark.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Score an Atom2.7m checkpoint on ArithMark 2.0."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
from collections import Counter
|
| 7 |
+
from contextlib import nullcontext
|
| 8 |
+
import json
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import re
|
| 11 |
+
import urllib.request
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from transformers import AutoModelForCausalLM
|
| 16 |
+
|
| 17 |
+
from tokenizer_utils import SPECIAL_TOKENS, FusionTokenizer, load_tokenizer
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
DATA_URL = (
|
| 21 |
+
"https://huggingface.co/datasets/AxiomicLabs/Arithmark-2.0/"
|
| 22 |
+
"resolve/main/arithmark_2.0.jsonl"
|
| 23 |
+
)
|
| 24 |
+
PAD_ID = SPECIAL_TOKENS.index("<|pad|>")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def ensure_data(path: Path) -> Path:
    """Return ``path``, downloading the benchmark file from ``DATA_URL`` if missing.

    The download is written to a sibling ``<name>.part`` file and renamed into
    place only after it completes. An interrupted transfer therefore never
    leaves a truncated file at ``path`` that a later run would mistake for a
    finished download.

    Args:
        path: Location where the JSONL dataset is expected.

    Returns:
        The same ``path`` that was passed in.

    Raises:
        OSError: If the download or the final rename fails (``urllib`` network
            errors are ``OSError`` subclasses).
    """
    if path.exists():
        return path
    path.parent.mkdir(parents=True, exist_ok=True)
    partial = path.with_name(path.name + ".part")
    try:
        urllib.request.urlretrieve(DATA_URL, partial)
        partial.replace(path)
    finally:
        # No-op after a successful rename; removes the fragment after a failure.
        partial.unlink(missing_ok=True)
    return path
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def load_examples(path: Path, *, max_examples: int = 0) -> list[dict]:
    """Read a JSONL file into a list of decoded records.

    Lines that are empty or whitespace-only are ignored.

    Args:
        path: JSONL file to read (decoded as UTF-8).
        max_examples: Stop after this many records; values ``<= 0`` mean
            "read the whole file".

    Returns:
        The decoded records in file order.
    """
    limit = max_examples if max_examples > 0 else None
    records = []
    with path.open("r", encoding="utf-8") as stream:
        for raw in stream:
            if raw.strip():
                records.append(json.loads(raw))
                if limit is not None and len(records) >= limit:
                    break
    return records
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _encoded_choice(
|
| 48 |
+
tokenizer: FusionTokenizer,
|
| 49 |
+
context: str,
|
| 50 |
+
ending: str,
|
| 51 |
+
) -> tuple[list[int], list[int], list[int], int]:
|
| 52 |
+
context_encoding = tokenizer.encode(context)
|
| 53 |
+
full_encoding = tokenizer.encode(context + ending)
|
| 54 |
+
continuation_length = len(full_encoding.input_ids) - len(context_encoding.input_ids)
|
| 55 |
+
return (
|
| 56 |
+
full_encoding.input_ids,
|
| 57 |
+
full_encoding.place_ids,
|
| 58 |
+
full_encoding.role_ids,
|
| 59 |
+
continuation_length,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _continuation_log_likelihood(
    log_probs: torch.Tensor,
    input_ids: torch.Tensor,
    row: int,
    start: int,
    end: int,
) -> float:
    """Sum the log-probabilities of tokens ``start..end-1`` in ``row``.

    The token at position ``p`` is scored by the distribution predicted at
    ``p - 1``. Position 0 has no preceding prediction, so it is skipped rather
    than letting index ``-1`` wrap around to the last (padding) position.
    An empty span scores ``0.0``.
    """
    first = max(start, 1)
    if first >= end:
        return 0.0
    targets = input_ids[row, first:end].unsqueeze(-1)
    picked = log_probs[row, first - 1 : end - 1].gather(-1, targets).squeeze(-1)
    # One device-to-host transfer per choice; summed left to right in Python
    # floats, the same order a per-token accumulation would use.
    return sum(picked.tolist(), 0.0)


def _accuracy_table(groups: dict[str, list[int]]) -> dict[str, dict]:
    """Convert ``{key: [correct, total]}`` tallies into sorted accuracy records."""
    return {
        key: {
            "accuracy": hits / max(seen, 1),
            "correct": hits,
            "total": seen,
        }
        for key, (hits, seen) in sorted(groups.items())
    }


@torch.inference_mode()
def evaluate(
    model: torch.nn.Module,
    tokenizer: FusionTokenizer,
    examples: list[dict],
    *,
    device: torch.device,
    batch_size: int,
    dump_failures: bool = False,
    failure_operator_count: int | None = None,
    max_failures: int = 100,
) -> dict:
    """Score multiple-choice examples by continuation log-likelihood.

    Every ``example["endings"]`` entry is appended to ``example["ctx"]``,
    encoded, and scored by the summed (not length-normalized) log-probability
    of the tokens the ending adds. The highest-scoring ending is the
    prediction and is compared with ``example["label"]``.

    Args:
        model: Called with ``input_ids``, ``place_ids``, ``role_ids`` and
            ``attention_mask``; its output must expose ``.logits``.
        tokenizer: Tokenizer passed through to ``_encoded_choice``.
        examples: Records with ``ctx``, ``endings``, ``label`` and optional
            ``metadata`` (``operator_count`` and ``topic`` are read from it).
        device: Device on which the batch tensors are created.
        batch_size: Number of examples (not choices) per forward pass.
        dump_failures: When true, add ``failure_summary`` and ``failures``
            to the result.
        failure_operator_count: If set, only failures whose operator count
            equals this value are recorded.
        max_failures: Maximum number of detailed failure records to keep; the
            grouped summary still counts every recorded failure.

    Returns:
        A dict with overall accuracy plus per-operator-count and per-topic
        breakdowns, and failure details when requested.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # A negative step would make the batch loop silently do nothing.
        raise ValueError("batch_size must be positive")

    correct = 0
    total = 0
    by_operator_count: dict[str, list[int]] = {}
    by_topic: dict[str, list[int]] = {}
    failures: list[dict] = []
    failure_summary: Counter[tuple[str, str, str]] = Counter()
    model.eval()

    for start in range(0, len(examples), batch_size):
        batch_examples = examples[start : start + batch_size]

        # Flatten every (example, ending) pair into one row of the batch and
        # remember where each example's rows begin.
        encoded = []
        offsets = []
        for example in batch_examples:
            flat_start = len(encoded)
            encoded.extend(
                _encoded_choice(tokenizer, example["ctx"], ending)
                for ending in example["endings"]
            )
            offsets.append((flat_start, len(example["endings"])))

        # Right-pad: token ids with PAD_ID, place/role ids with 0.
        max_length = max(len(item[0]) for item in encoded)
        input_ids = torch.full(
            (len(encoded), max_length),
            PAD_ID,
            dtype=torch.long,
            device=device,
        )
        place_ids = torch.zeros_like(input_ids)
        role_ids = torch.zeros_like(input_ids)
        attention_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        lengths = []
        continuation_lengths = []
        for row, (ids, places, roles, continuation_length) in enumerate(encoded):
            length = len(ids)
            input_ids[row, :length] = torch.tensor(ids, device=device)
            place_ids[row, :length] = torch.tensor(places, device=device)
            role_ids[row, :length] = torch.tensor(roles, device=device)
            attention_mask[row, :length] = True
            lengths.append(length)
            continuation_lengths.append(continuation_length)

        autocast = (
            torch.autocast(device_type="cuda", dtype=torch.bfloat16)
            if device.type == "cuda"
            else nullcontext()
        )
        with autocast:
            logits = model(
                input_ids=input_ids,
                place_ids=place_ids,
                role_ids=role_ids,
                attention_mask=attention_mask,
            ).logits
        # Softmax in float32 regardless of the autocast dtype.
        log_probs = F.log_softmax(logits.float(), dim=-1)

        for example_index, example in enumerate(batch_examples):
            flat_start, choice_count = offsets[example_index]
            likelihoods = []
            for choice_index in range(choice_count):
                row = flat_start + choice_index
                length = lengths[row]
                continuation_start = length - continuation_lengths[row]
                likelihoods.append(
                    _continuation_log_likelihood(
                        log_probs, input_ids, row, continuation_start, length
                    )
                )

            prediction = max(range(choice_count), key=likelihoods.__getitem__)
            label = int(example["label"])
            matched = prediction == label
            correct += int(matched)
            total += 1
            metadata = example.get("metadata", {})
            operator_count = str(metadata.get("operator_count", "unknown"))
            topic = str(metadata.get("topic", "unknown"))
            for grouped, key in (
                (by_operator_count, operator_count),
                (by_topic, topic),
            ):
                group = grouped.setdefault(key, [0, 0])
                group[0] += int(matched)
                group[1] += 1

            if not matched and dump_failures:
                # A non-numeric operator count ("unknown") never matches an
                # explicit filter value.
                op_count_int = None
                try:
                    op_count_int = int(operator_count)
                except ValueError:
                    pass
                if failure_operator_count is None or op_count_int == failure_operator_count:
                    context = str(example["ctx"]).strip()
                    expression = context[:-1].strip() if context.endswith("=") else context
                    operands = [int(value) for value in re.findall(r"\d+", expression)]
                    operator = "".join(re.findall(r"[+\-*/]", expression))
                    predicted_answer = str(example["endings"][prediction]).strip()
                    correct_answer = str(example["endings"][label]).strip()
                    width = max((len(str(value)) for value in operands), default=0)
                    failure_summary[(topic, operator, f"width={width}")] += 1
                    if len(failures) < max_failures:
                        failures.append(
                            {
                                "ctx": context,
                                "topic": topic,
                                "operator_count": operator_count,
                                "operator": operator,
                                "operands": operands,
                                "max_operand_digits": width,
                                "correct_answer": correct_answer,
                                "predicted_answer": predicted_answer,
                                "choices": [str(value).strip() for value in example["endings"]],
                                "choice_scores": [round(value, 4) for value in likelihoods],
                                "score_margin_correct_minus_predicted": round(
                                    likelihoods[label] - likelihoods[prediction],
                                    4,
                                ),
                            }
                        )

    results = {
        "benchmark": "arithmark_2.0",
        "model_type": "atom2.7m",
        "accuracy": correct / max(total, 1),
        "correct": correct,
        "total": total,
        "by_operator_count": _accuracy_table(by_operator_count),
        "by_topic": _accuracy_table(by_topic),
    }
    if dump_failures:
        results["failure_summary"] = {
            "|".join(key): value
            for key, value in failure_summary.most_common()
        }
        results["failures"] = failures
    return results
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def parse_args() -> argparse.Namespace:
    """Build the benchmark command-line interface and parse ``sys.argv``."""
    cli = argparse.ArgumentParser(description=__doc__)

    # Inputs and runtime settings.
    cli.add_argument(
        "--checkpoint",
        type=Path,
        default=Path("outputs/fusion_run/final_model"),
    )
    cli.add_argument("--tokenizer-dir", type=Path, default=Path("tokenizer_4k"))
    cli.add_argument("--data-path", type=Path, default=Path("arithmark_2.0.jsonl"))
    cli.add_argument("--batch-size", type=int, default=64)
    cli.add_argument("--device", default="auto")
    cli.add_argument("--output", type=Path)
    cli.add_argument(
        "--max-examples",
        type=int,
        default=0,
        help="Evaluate only the first N examples. Default evaluates all examples.",
    )

    # Failure reporting.
    cli.add_argument(
        "--dump-failures",
        action="store_true",
        help="Include incorrectly scored examples and grouped failure summary.",
    )
    cli.add_argument(
        "--failure-operator-count",
        type=int,
        default=None,
        help="Only dump failures with this operator count, e.g. 1 for easy examples.",
    )
    cli.add_argument("--max-failures", type=int, default=100)

    namespace = cli.parse_args()
    return namespace
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def main() -> None:
    """Run the benchmark end to end: load data and model, evaluate, report.

    The JSON report is printed to stdout and, when ``--output`` is given,
    also written to that file (parent directories are created as needed).
    """
    args = parse_args()

    # "auto" prefers CUDA when available; anything else is passed to torch as-is.
    device_name = args.device
    if device_name == "auto":
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    examples = load_examples(
        ensure_data(args.data_path),
        max_examples=args.max_examples,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint,
        trust_remote_code=True,
    ).to(device)
    tokenizer = load_tokenizer(args.tokenizer_dir)

    results = evaluate(
        model,
        tokenizer,
        examples,
        device=device,
        batch_size=args.batch_size,
        dump_failures=args.dump_failures,
        failure_operator_count=args.failure_operator_count,
        max_failures=args.max_failures,
    )

    # Serialize once and reuse the text for both stdout and the output file.
    report = json.dumps(results, indent=2, sort_keys=True)
    print(report)
    if args.output:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(report + "\n", encoding="utf-8")
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
if __name__ == "__main__":
|
| 291 |
+
main()
|
bg.png
ADDED
|
Git LFS Details
|
config.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"GPTForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "config.GPTConfig",
|
| 7 |
+
"AutoModelForCausalLM": "model.GPTForCausalLM"
|
| 8 |
+
},
|
| 9 |
+
"block_size": 512,
|
| 10 |
+
"dtype": "float32",
|
| 11 |
+
"head_dim": 48,
|
| 12 |
+
"hidden_size": 192,
|
| 13 |
+
"intermediate_size": 480,
|
| 14 |
+
"labels_are_shifted": true,
|
| 15 |
+
"max_position_embeddings": 512,
|
| 16 |
+
"model_type": "gpt",
|
| 17 |
+
"num_attention_heads": 4,
|
| 18 |
+
"num_hidden_layers": 5,
|
| 19 |
+
"num_key_value_heads": 2,
|
| 20 |
+
"place_vocab_size": 66,
|
| 21 |
+
"rms_norm_eps": 1e-06,
|
| 22 |
+
"role_vocab_size": 12,
|
| 23 |
+
"rope_theta": 5000.0,
|
| 24 |
+
"transformers_version": "4.57.6",
|
| 25 |
+
"use_place_embeddings": true,
|
| 26 |
+
"use_role_embeddings": true,
|
| 27 |
+
"vocab_size": 4096,
|
| 28 |
+
"xsa_projection": true
|
| 29 |
+
}
|
config.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Environment-driven training configuration."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import math
|
| 7 |
+
import uuid
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
from transformers import PretrainedConfig
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
DEFAULT_VOCAB_SIZE = 4096
|
| 15 |
+
DEFAULT_HIDDEN_SIZE = 192
|
| 16 |
+
DEFAULT_NUM_HIDDEN_LAYERS = 5
|
| 17 |
+
DEFAULT_NUM_ATTENTION_HEADS = 4
|
| 18 |
+
DEFAULT_NUM_KEY_VALUE_HEADS = 2
|
| 19 |
+
DEFAULT_HEAD_DIM = DEFAULT_HIDDEN_SIZE // DEFAULT_NUM_ATTENTION_HEADS
|
| 20 |
+
DEFAULT_INTERMEDIATE_SIZE = DEFAULT_HIDDEN_SIZE * 5 // 2
|
| 21 |
+
DEFAULT_BLOCK_SIZE = 512
|
| 22 |
+
DEFAULT_ROPE_THETA = 5000.0
|
| 23 |
+
DEFAULT_PLACE_VOCAB_SIZE = 66
|
| 24 |
+
DEFAULT_ROLE_VOCAB_SIZE = 12
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class GPTConfig(PretrainedConfig):
    """Configuration for the custom decoder-only GPT model.

    Optional sizes are derived when left as ``None``:

    * ``num_key_value_heads`` falls back to ``num_attention_heads``.
    * ``head_dim`` falls back to ``hidden_size // num_attention_heads``.
    * ``intermediate_size`` falls back to ``4 * hidden_size``.

    ``max_position_embeddings`` always mirrors ``block_size``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by
            ``num_attention_heads`` (only checked when ``head_dim`` is
            derived), if ``num_attention_heads`` is not divisible by
            ``num_key_value_heads``, or if ``head_dim`` is odd.
    """

    model_type = "gpt"

    def __init__(
        self,
        vocab_size: int = DEFAULT_VOCAB_SIZE,
        hidden_size: int = DEFAULT_HIDDEN_SIZE,
        num_hidden_layers: int = DEFAULT_NUM_HIDDEN_LAYERS,
        num_attention_heads: int = DEFAULT_NUM_ATTENTION_HEADS,
        num_key_value_heads: int | None = DEFAULT_NUM_KEY_VALUE_HEADS,
        intermediate_size: int | None = DEFAULT_INTERMEDIATE_SIZE,
        head_dim: int | None = None,
        block_size: int = DEFAULT_BLOCK_SIZE,
        rope_theta: float = DEFAULT_ROPE_THETA,
        rms_norm_eps: float = 1e-6,
        xsa_projection: bool = True,
        tie_word_embeddings: bool = True,
        labels_are_shifted: bool = False,
        use_place_embeddings: bool = True,
        use_role_embeddings: bool = True,
        place_vocab_size: int = DEFAULT_PLACE_VOCAB_SIZE,
        role_vocab_size: int = DEFAULT_ROLE_VOCAB_SIZE,
        **kwargs,
    ):
        # Fill in the sizes the caller left unspecified.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        if head_dim is None:
            per_head, leftover = divmod(hidden_size, num_attention_heads)
            if leftover != 0:
                raise ValueError("hidden_size must be divisible by num_attention_heads")
            head_dim = per_head
        if intermediate_size is None:
            intermediate_size = hidden_size * 4

        # Validate the resolved geometry before touching the base class.
        if num_attention_heads % num_key_value_heads != 0:
            raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
        if head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Store every field coerced to its canonical Python type.
        typed_fields = (
            ("vocab_size", int, vocab_size),
            ("hidden_size", int, hidden_size),
            ("num_hidden_layers", int, num_hidden_layers),
            ("num_attention_heads", int, num_attention_heads),
            ("num_key_value_heads", int, num_key_value_heads),
            ("intermediate_size", int, intermediate_size),
            ("head_dim", int, head_dim),
            ("block_size", int, block_size),
            ("max_position_embeddings", int, block_size),
            ("rope_theta", float, rope_theta),
            ("rms_norm_eps", float, rms_norm_eps),
            ("xsa_projection", bool, xsa_projection),
            ("labels_are_shifted", bool, labels_are_shifted),
            ("use_place_embeddings", bool, use_place_embeddings),
            ("use_role_embeddings", bool, use_role_embeddings),
            ("place_vocab_size", int, place_vocab_size),
            ("role_vocab_size", int, role_vocab_size),
        )
        for attribute, coerce, value in typed_fields:
            setattr(self, attribute, coerce(value))
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _bool_env(name: str, default: bool) -> bool:
|
| 85 |
+
raw = os.environ.get(name)
|
| 86 |
+
if raw is None:
|
| 87 |
+
return default
|
| 88 |
+
return raw.strip().lower() in {"1", "true", "yes", "on"}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _path_env(name: str, default: str) -> str:
|
| 92 |
+
return str(Path(os.environ.get(name, default)).expanduser())
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@dataclass
|
| 96 |
+
class Hyperparameters:
|
| 97 |
+
data_dir: str = field(default_factory=lambda: _path_env("DATA_DIR", "."))
|
| 98 |
+
tokenized_dir: str = field(default_factory=lambda: _path_env("TOKENIZED_DIR", "tokenized2"))
|
| 99 |
+
tokenizer_dir: str = field(default_factory=lambda: _path_env("TOKENIZER_DIR", "tokenizer_4k"))
|
| 100 |
+
tokenizer_path: str = field(default_factory=lambda: os.environ.get("TOKENIZER_PATH", ""))
|
| 101 |
+
curriculum_path: str = field(default_factory=lambda: os.environ.get("CURRICULUM_PATH", ""))
|
| 102 |
+
mix_weights_path: str = field(default_factory=lambda: os.environ.get("MIX_WEIGHTS_PATH", ""))
|
| 103 |
+
run_id: str = field(default_factory=lambda: os.environ.get("RUN_ID", str(uuid.uuid4())))
|
| 104 |
+
seed: int = field(default_factory=lambda: int(os.environ.get("SEED", "1337")))
|
| 105 |
+
rank: int = field(init=False)
|
| 106 |
+
|
| 107 |
+
iterations: int = field(default_factory=lambda: int(os.environ.get("ITERATIONS", "10000")))
|
| 108 |
+
requested_train_tokens: int = field(init=False)
|
| 109 |
+
train_tokens: int = field(init=False)
|
| 110 |
+
decay_start_frac: float = field(default_factory=lambda: float(os.environ.get("DECAY_START_FRAC", "0.7")))
|
| 111 |
+
warmup_steps: int = field(default_factory=lambda: int(os.environ.get("WARMUP_STEPS", "0")))
|
| 112 |
+
lr_warmup_steps: int = field(default_factory=lambda: int(os.environ.get("LR_WARMUP_STEPS", "500")))
|
| 113 |
+
train_batch_tokens: int = field(default_factory=lambda: int(os.environ.get("TRAIN_BATCH_TOKENS", str(DEFAULT_BLOCK_SIZE * 512))))
|
| 114 |
+
train_seq_len: int = field(init=False)
|
| 115 |
+
eval_seq_len: int = field(init=False)
|
| 116 |
+
grad_accum_steps: int = field(default_factory=lambda: int(os.environ.get("GRAD_ACCUM_STEPS", "4")))
|
| 117 |
+
train_log_every: int = field(default_factory=lambda: int(os.environ.get("TRAIN_LOG_EVERY", "100")))
|
| 118 |
+
train_log_dense_steps: int = field(default_factory=lambda: int(os.environ.get("TRAIN_LOG_DENSE_STEPS", "100")))
|
| 119 |
+
train_log_ramp_steps: int = field(
|
| 120 |
+
default_factory=lambda: int(
|
| 121 |
+
os.environ.get(
|
| 122 |
+
"TRAIN_LOG_RAMP_STEPS",
|
| 123 |
+
os.environ.get("TRAIN_LOG_FIRST_STEPS", "500"),
|
| 124 |
+
)
|
| 125 |
+
)
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
val_batch_tokens: int = field(default_factory=lambda: int(os.environ.get("VAL_BATCH_TOKENS", str(DEFAULT_BLOCK_SIZE * 256))))
|
| 129 |
+
val_loss_every: int = field(default_factory=lambda: int(os.environ.get("VAL_LOSS_EVERY", "1000")))
|
| 130 |
+
val_max_tokens: int = field(default_factory=lambda: int(os.environ.get("VAL_MAX_TOKENS", "10_000_000")))
|
| 131 |
+
|
| 132 |
+
vocab_size: int = field(default_factory=lambda: int(os.environ.get("VOCAB_SIZE", str(DEFAULT_VOCAB_SIZE))))
|
| 133 |
+
hidden_size: int = field(default_factory=lambda: int(os.environ.get("HIDDEN_SIZE", os.environ.get("MODEL_DIM", str(DEFAULT_HIDDEN_SIZE)))))
|
| 134 |
+
num_hidden_layers: int = field(default_factory=lambda: int(os.environ.get("NUM_HIDDEN_LAYERS", os.environ.get("NUM_LAYERS", str(DEFAULT_NUM_HIDDEN_LAYERS)))))
|
| 135 |
+
num_attention_heads: int = field(default_factory=lambda: int(os.environ.get("NUM_ATTENTION_HEADS", os.environ.get("NUM_HEADS", str(DEFAULT_NUM_ATTENTION_HEADS)))))
|
| 136 |
+
num_key_value_heads: int = field(default_factory=lambda: int(os.environ.get("NUM_KEY_VALUE_HEADS", os.environ.get("NUM_KV_HEADS", str(DEFAULT_NUM_KEY_VALUE_HEADS)))))
|
| 137 |
+
head_dim: int = field(init=False)
|
| 138 |
+
intermediate_size: int = field(default_factory=lambda: int(os.environ.get("INTERMEDIATE_SIZE", os.environ.get("INTERMEDIATE", str(DEFAULT_INTERMEDIATE_SIZE)))))
|
| 139 |
+
block_size: int = field(default_factory=lambda: int(os.environ.get("BLOCK_SIZE", str(DEFAULT_BLOCK_SIZE))))
|
| 140 |
+
rope_theta: float = field(default_factory=lambda: float(os.environ.get("ROPE_THETA", os.environ.get("ROPE_BASE", str(DEFAULT_ROPE_THETA)))))
|
| 141 |
+
rms_norm_eps: float = field(default_factory=lambda: float(os.environ.get("RMS_NORM_EPS", "1e-6")))
|
| 142 |
+
xsa_projection: bool = field(default_factory=lambda: _bool_env("XSA_PROJECTION", True))
|
| 143 |
+
tie_word_embeddings: bool = field(default_factory=lambda: _bool_env("TIE_WORD_EMBEDDINGS", _bool_env("TIE_EMBEDDINGS", True)))
|
| 144 |
+
use_place_embeddings: bool = field(default_factory=lambda: _bool_env("USE_PLACE_EMBEDDINGS", True))
|
| 145 |
+
use_role_embeddings: bool = field(default_factory=lambda: _bool_env("USE_ROLE_EMBEDDINGS", True))
|
| 146 |
+
place_vocab_size: int = field(default_factory=lambda: int(os.environ.get("PLACE_VOCAB_SIZE", str(DEFAULT_PLACE_VOCAB_SIZE))))
|
| 147 |
+
role_vocab_size: int = field(default_factory=lambda: int(os.environ.get("ROLE_VOCAB_SIZE", str(DEFAULT_ROLE_VOCAB_SIZE))))
|
| 148 |
+
|
| 149 |
+
min_lr: float = field(default_factory=lambda: float(os.environ.get("MIN_LR", "0.0")))
|
| 150 |
+
lr: float = field(default_factory=lambda: float(os.environ.get("LR", "0.004")))
|
| 151 |
+
beta1: float = field(default_factory=lambda: float(os.environ.get("BETA1", "0.9")))
|
| 152 |
+
beta2: float = field(default_factory=lambda: float(os.environ.get("BETA2", "0.95")))
|
| 153 |
+
adam_eps: float = field(default_factory=lambda: float(os.environ.get("ADAM_EPS", "1e-8")))
|
| 154 |
+
weight_decay: float = field(default_factory=lambda: float(os.environ.get("WEIGHT_DECAY", "0.005")))
|
| 155 |
+
|
| 156 |
+
compile_model: bool = field(default_factory=lambda: _bool_env("COMPILE_MODEL", True))
|
| 157 |
+
autocast: bool = field(default_factory=lambda: _bool_env("AUTOCAST", True))
|
| 158 |
+
bf16: bool = field(default_factory=lambda: _bool_env("BF16", True))
|
| 159 |
+
device: str = field(default_factory=lambda: os.environ.get("DEVICE", "auto"))
|
| 160 |
+
|
| 161 |
+
output_dir: str = field(default_factory=lambda: _path_env("OUTPUT_DIR", "outputs"))
|
| 162 |
+
checkpoint_name: str = field(default_factory=lambda: os.environ.get("CHECKPOINT_NAME", "final_model"))
|
| 163 |
+
logfile: str = field(init=False)
|
| 164 |
+
model_path: str = field(init=False)
|
| 165 |
+
is_main_process: bool = True
|
| 166 |
+
train_files: str = field(init=False)
|
| 167 |
+
val_files: str = field(init=False)
|
| 168 |
+
|
| 169 |
+
def __post_init__(self) -> None:
    """Derive runtime settings from the environment and validate the token budget.

    Sets ``rank``/``is_main_process``, ``head_dim``, the train/eval sequence
    lengths, the aligned token budget (``requested_train_tokens``,
    ``train_tokens``, ``iterations``), the data globs, and the default
    curriculum / mix-weights / tokenizer / log / checkpoint paths.

    Raises:
        ValueError: if ``RANK`` is negative, if the sequence length,
            gradient-accumulation steps or batch-token count is not
            positive, if the batch-token count is not a multiple of
            ``TRAIN_SEQ_LEN * GRAD_ACCUM_STEPS``, or if the token budget is
            too small to hold one aligned step.
    """
    env = os.environ

    self.rank = int(env.get("RANK", "0"))
    if self.rank < 0:
        raise ValueError("RANK must be non-negative")
    self.is_main_process = self.rank == 0

    self.head_dim = int(env.get("HEAD_DIM", str(self.hidden_size // self.num_attention_heads)))
    self.train_seq_len = int(env.get("TRAIN_SEQ_LEN", str(self.block_size)))
    # train_seq_len already reflects TRAIN_SEQ_LEN, so it is the only
    # fallback EVAL_SEQ_LEN needs.
    self.eval_seq_len = int(env.get("EVAL_SEQ_LEN", str(self.train_seq_len)))

    # Reject non-positive sizes up front: they would otherwise surface as a
    # ZeroDivisionError in the modulo / division below instead of a clear
    # configuration error.
    for label, amount in (
        ("TRAIN_SEQ_LEN", self.train_seq_len),
        ("GRAD_ACCUM_STEPS", self.grad_accum_steps),
        ("TRAIN_BATCH_TOKENS", self.train_batch_tokens),
    ):
        if amount <= 0:
            raise ValueError(f"{label} must be positive")

    token_alignment = self.train_seq_len * self.grad_accum_steps
    if self.train_batch_tokens % token_alignment != 0:
        raise ValueError(
            "TRAIN_BATCH_TOKENS must be divisible by TRAIN_SEQ_LEN * GRAD_ACCUM_STEPS"
        )

    # TRAIN_TOKENS=0 (or unset) means "derive the budget from iterations".
    requested_train_tokens = int(env.get("TRAIN_TOKENS", "0"))
    self.requested_train_tokens = requested_train_tokens or self.iterations * self.train_batch_tokens
    if self.requested_train_tokens <= 0:
        raise ValueError("TRAIN_TOKENS must be positive")

    # Round the budget down to a whole number of aligned micro-steps.
    self.train_tokens = self.requested_train_tokens - (self.requested_train_tokens % token_alignment)
    if self.train_tokens <= 0:
        raise ValueError(
            "TRAIN_TOKENS must provide at least TRAIN_SEQ_LEN * GRAD_ACCUM_STEPS tokens"
        )
    self.iterations = math.ceil(self.train_tokens / self.train_batch_tokens)

    tokenized = Path(self.tokenized_dir)
    self.train_files = env.get("TRAIN_FILES", str(tokenized / "*" / "shard_*.bin"))
    # train_files already reflects TRAIN_FILES, so it doubles as the
    # validation default.
    self.val_files = env.get("VAL_FILES", self.train_files)

    # An explicit MIX_WEIGHTS_PATH disables curriculum auto-discovery.
    explicit_legacy_mix = bool(env.get("MIX_WEIGHTS_PATH"))
    if not self.curriculum_path and not explicit_legacy_mix:
        tokenized_curriculum = tokenized / "curriculum.json"
        default_curriculum = Path("pretraining_curriculum.json")
        if tokenized_curriculum.exists():
            self.curriculum_path = str(tokenized_curriculum)
        elif default_curriculum.exists():
            self.curriculum_path = str(default_curriculum)

    if not self.mix_weights_path and not self.curriculum_path:
        mix_weights_path = tokenized / "mix_weights.json"
        self.mix_weights_path = str(mix_weights_path) if mix_weights_path.exists() else ""

    if not self.tokenizer_path:
        tok_dir = Path(self.tokenizer_dir)
        json_path = tok_dir / "tokenizer.json"
        # Prefer the serialized tokenizer file; otherwise point at the directory.
        self.tokenizer_path = str(json_path if json_path.exists() else tok_dir)

    out = Path(self.output_dir)
    self.logfile = env.get("LOGFILE", str(out / "logs" / f"{self.run_id}.txt"))
    self.model_path = env.get("MODEL_PATH", str(out / self.checkpoint_name))
|
| 213 |
+
|
| 214 |
+
def to_model_config(self) -> GPTConfig:
    """Build the model-side ``GPTConfig`` from this object's architecture fields.

    Every listed attribute is forwarded under the same keyword name;
    ``labels_are_shifted`` is always passed as ``True``.
    """
    mirrored_fields = (
        "vocab_size",
        "hidden_size",
        "num_hidden_layers",
        "num_attention_heads",
        "num_key_value_heads",
        "head_dim",
        "intermediate_size",
        "block_size",
        "rope_theta",
        "rms_norm_eps",
        "xsa_projection",
        "tie_word_embeddings",
        "use_place_embeddings",
        "use_role_embeddings",
        "place_vocab_size",
        "role_vocab_size",
    )
    model_kwargs = {name: getattr(self, name) for name in mirrored_fields}
    return GPTConfig(**model_kwargs, labels_are_shifted=True)
|
configuration_gpt.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Exports for the GPT model configuration.
|
| 2 |
+
|
| 3 |
+
New code should import these from :mod:`GPT.config`.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .config import (
|
| 7 |
+
DEFAULT_BLOCK_SIZE,
|
| 8 |
+
DEFAULT_HEAD_DIM,
|
| 9 |
+
DEFAULT_HIDDEN_SIZE,
|
| 10 |
+
DEFAULT_INTERMEDIATE_SIZE,
|
| 11 |
+
DEFAULT_NUM_ATTENTION_HEADS,
|
| 12 |
+
DEFAULT_NUM_HIDDEN_LAYERS,
|
| 13 |
+
DEFAULT_NUM_KEY_VALUE_HEADS,
|
| 14 |
+
DEFAULT_PLACE_VOCAB_SIZE,
|
| 15 |
+
DEFAULT_ROPE_THETA,
|
| 16 |
+
DEFAULT_ROLE_VOCAB_SIZE,
|
| 17 |
+
DEFAULT_VOCAB_SIZE,
|
| 18 |
+
GPTConfig,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
"DEFAULT_BLOCK_SIZE",
|
| 24 |
+
"DEFAULT_HEAD_DIM",
|
| 25 |
+
"DEFAULT_HIDDEN_SIZE",
|
| 26 |
+
"DEFAULT_INTERMEDIATE_SIZE",
|
| 27 |
+
"DEFAULT_NUM_ATTENTION_HEADS",
|
| 28 |
+
"DEFAULT_NUM_HIDDEN_LAYERS",
|
| 29 |
+
"DEFAULT_NUM_KEY_VALUE_HEADS",
|
| 30 |
+
"DEFAULT_PLACE_VOCAB_SIZE",
|
| 31 |
+
"DEFAULT_ROPE_THETA",
|
| 32 |
+
"DEFAULT_ROLE_VOCAB_SIZE",
|
| 33 |
+
"DEFAULT_VOCAB_SIZE",
|
| 34 |
+
"GPTConfig",
|
| 35 |
+
]
|
lm_eval_fusion
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""Run lm-eval with the local Atom2.7m model registered."""
|
| 3 |
+
|
| 4 |
+
import lm_eval_fusion # noqa: F401
|
| 5 |
+
from lm_eval.__main__ import cli_evaluate
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
if __name__ == "__main__":
|
| 9 |
+
cli_evaluate()
|
lm_eval_fusion.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""lm-eval wrapper for Atom2.7m checkpoints.
|
| 2 |
+
|
| 3 |
+
The standard ``hf`` lm-eval model does not use the fusion tokenizer wrapper and
|
| 4 |
+
does not pass arithmetic feature streams. This model keeps lm-eval's
|
| 5 |
+
log-likelihood interface while encoding with ``tokenizer_utils.load_tokenizer``
|
| 6 |
+
and forwarding ``place_ids`` and ``role_ids``.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from contextlib import nullcontext
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from lm_eval.api.model import LM
|
| 18 |
+
from lm_eval.api.registry import register_model
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
from transformers import AutoModelForCausalLM
|
| 21 |
+
|
| 22 |
+
from tokenizer_utils import EOT_ID, FusionTokenizer, load_tokenizer
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _parse_bool(value: Any, default: bool = False) -> bool:
    """Interpret a loosely-typed flag as a boolean.

    ``None`` yields ``default`` and real booleans pass through unchanged.
    Anything else is stringified, stripped and lower-cased; it counts as true
    only when it is one of ``1``, ``true``, ``yes`` or ``on``.
    """
    truthy_words = {"1", "true", "yes", "on"}
    if isinstance(value, bool):
        return value
    if value is None:
        return default
    normalized = str(value).strip().lower()
    return normalized in truthy_words
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _parse_batch_size(value: int | str | None, max_batch_size: int | None) -> int:
|
| 34 |
+
if value is None:
|
| 35 |
+
return 1
|
| 36 |
+
if isinstance(value, int):
|
| 37 |
+
return value
|
| 38 |
+
text = str(value).strip().lower()
|
| 39 |
+
if text == "auto" or text.startswith("auto:"):
|
| 40 |
+
return int(max_batch_size or 64)
|
| 41 |
+
return int(text)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _dtype_from_name(value: str | torch.dtype | None) -> torch.dtype | None:
|
| 45 |
+
if value is None or value == "auto":
|
| 46 |
+
return None
|
| 47 |
+
if isinstance(value, torch.dtype):
|
| 48 |
+
return value
|
| 49 |
+
normalized = str(value).replace("torch.", "").lower()
|
| 50 |
+
if normalized in {"bf16", "bfloat16"}:
|
| 51 |
+
return torch.bfloat16
|
| 52 |
+
if normalized in {"fp16", "float16", "half"}:
|
| 53 |
+
return torch.float16
|
| 54 |
+
if normalized in {"fp32", "float32", "float"}:
|
| 55 |
+
return torch.float32
|
| 56 |
+
raise ValueError(f"Unsupported dtype: {value!r}")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@register_model("atom2.7m")
class FusionGPTLM(LM):
    """Fusion-tokenizer GPT adapter for lm-eval log-likelihood tasks.

    Text is encoded with the fusion tokenizer, and the resulting
    ``place_ids`` / ``role_ids`` streams are forwarded to the model next to
    ``input_ids``. Only the log-likelihood interfaces are implemented;
    ``generate_until`` raises ``NotImplementedError``.
    """

    def __init__(
        self,
        pretrained: str = "outputs/fusion_run/final_model",
        tokenizer_dir: str = "tokenizer_4k",
        batch_size: int | str | None = 1,
        max_batch_size: int | None = 64,
        max_length: int | None = None,
        device: str | None = "cuda",
        dtype: str | torch.dtype | None = "auto",
        mixed_precision_dtype: str | torch.dtype | None = "auto",
        trust_remote_code: bool | str | None = None,
        **_: Any,
    ) -> None:
        """Load the tokenizer and checkpoint and fix the evaluation settings.

        Args:
            pretrained: checkpoint path handed to ``from_pretrained``.
            tokenizer_dir: directory handed to ``load_tokenizer``.
            batch_size: int, numeric string, or ``"auto"`` (uses ``max_batch_size``).
            max_batch_size: batch size used for the ``auto`` forms.
            max_length: maximum model input length; defaults to the model
                config's ``block_size``, then ``max_position_embeddings``, then 512.
            device: torch device string; ``None``/``"auto"`` picks CUDA when available.
            dtype: optional dtype to convert the model to after loading.
            mixed_precision_dtype: autocast dtype; ``"auto"`` means bfloat16 on
                CUDA and no autocast elsewhere.
            trust_remote_code: accepted for interface compatibility and ignored;
                the checkpoint is always loaded with ``trust_remote_code=True``.
            **_: extra model arguments, ignored.
        """
        super().__init__()
        del trust_remote_code
        if device is None or device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = torch.device(device)
        self.batch_size = _parse_batch_size(batch_size, max_batch_size)
        self.tokenizer: FusionTokenizer = load_tokenizer(Path(tokenizer_dir))
        self.model = AutoModelForCausalLM.from_pretrained(
            Path(pretrained),
            trust_remote_code=True,
        ).to(self.device)
        model_dtype = _dtype_from_name(dtype)
        if model_dtype is not None:
            self.model = self.model.to(dtype=model_dtype)
        if mixed_precision_dtype == "auto":
            self.mixed_precision_dtype = (
                torch.bfloat16 if self.device.type == "cuda" else None
            )
        else:
            self.mixed_precision_dtype = _dtype_from_name(mixed_precision_dtype)
        self.model.eval()
        self.max_length = int(
            max_length
            or getattr(self.model.config, "block_size", None)
            or getattr(self.model.config, "max_position_embeddings", 512)
        )

    @property
    def device(self) -> torch.device:
        """Device the model and all input tensors live on.

        Defined here because this class only ever assigns ``self._device``
        while reading ``self.device`` everywhere; without this accessor the
        class would depend on a base class supplying the attribute.
        """
        return self._device

    @property
    def eot_token_id(self) -> int:
        """Token id used as the empty-context prefix and as the padding value."""
        return EOT_ID

    def tok_encode(
        self,
        string: str,
        add_special_tokens: bool | None = None,
        left_truncate_len: int | None = None,
        **_: Any,
    ) -> list[int]:
        """Encode ``string`` to token ids, optionally keeping only the last tokens.

        ``add_special_tokens`` is accepted for interface compatibility and ignored.
        """
        del add_special_tokens
        ids = self.tokenizer.encode(string).input_ids
        if left_truncate_len is not None:
            ids = ids[-left_truncate_len:]
        return ids

    def tok_decode(self, tokens, skip_special_tokens: bool = True) -> str:
        """Decode a single token id or a sequence of ids back to text."""
        if isinstance(tokens, int):
            tokens = [tokens]
        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)

    def _encode_request(
        self,
        context: str,
        continuation: str,
    ) -> tuple[list[int], list[int], list[int], list[int], int]:
        """Encode one (context, continuation) pair for scoring.

        Trailing whitespace of the context is moved to the front of the
        continuation before encoding. If the context is empty (or becomes
        empty after that move) the sequence is prefixed with the EOT token so
        that the first continuation token still has something to condition on.

        Returns:
            ``(ids, place_ids, role_ids, continuation_ids, context_len)`` where
            ``ids`` is the full sequence and ``context_len`` is the number of
            leading tokens that belong to the context.

        Raises:
            ValueError: if the continuation encodes to zero tokens.
        """
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        if context == "":
            # Covers both a genuinely empty context and a whitespace-only one.
            # The latter previously reached the scorer with context_len == 0,
            # which produced a negative score slice and a wrong result.
            continuation_encoding = self.tokenizer.encode(continuation)
            ids = [self.eot_token_id] + continuation_encoding.input_ids
            place_ids = [0] + continuation_encoding.place_ids
            role_ids = [0] + continuation_encoding.role_ids
            context_len = 1
            continuation_ids = continuation_encoding.input_ids
        else:
            full_encoding = self.tokenizer.encode(context + continuation)
            context_encoding = self.tokenizer.encode(context)
            ids = full_encoding.input_ids
            place_ids = full_encoding.place_ids
            role_ids = full_encoding.role_ids
            context_len = len(context_encoding.input_ids)
            continuation_ids = ids[context_len:]

        if not continuation_ids:
            raise ValueError("Continuation encoded to zero tokens")
        return ids, place_ids, role_ids, continuation_ids, context_len

    def _autocast_context(self):
        """Return the autocast context for a forward pass (a no-op off CUDA)."""
        if self.device.type != "cuda":
            return nullcontext()
        return torch.autocast(
            device_type=self.device.type,
            dtype=self.mixed_precision_dtype,
            enabled=self.mixed_precision_dtype is not None,
        )

    def loglikelihood(
        self,
        requests: list["Instance"],
        disable_tqdm: bool = False,
    ) -> list[tuple[float, bool]]:
        """Score each request's continuation given its context.

        Returns one ``(log_likelihood, is_greedy)`` pair per request, in
        request order. ``is_greedy`` is true when the argmax prediction matches
        every scored continuation token.

        Raises:
            ValueError: if a continuation is empty, or if truncation to
                ``max_length`` leaves no continuation token to score.
        """
        encoded = [
            self._encode_request(context, continuation)
            for context, continuation in tqdm(
                [req.args for req in requests],
                desc="Fusion tokenizing inputs",
                # Same rank gating as the scoring progress bar below.
                disable=disable_tqdm or self.rank != 0,
            )
        ]
        results: list[tuple[float, bool]] = []
        for start in tqdm(
            range(0, len(encoded), self.batch_size),
            desc="Running fusion loglikelihood requests",
            disable=disable_tqdm or self.rank != 0,
        ):
            batch = encoded[start : start + self.batch_size]
            rows = []
            row_places = []
            row_roles = []
            row_targets = []
            row_score_slices = []
            for ids, place_ids, role_ids, continuation_ids, context_len in batch:
                # Keep the last max_length + 1 tokens so the model input
                # (everything but the final token) has at most max_length tokens.
                window_start = max(0, len(ids) - (self.max_length + 1))
                window_ids = ids[window_start:]
                window_places = place_ids[window_start:]
                window_roles = role_ids[window_start:]
                input_ids = window_ids[:-1]
                targets = window_ids[1:]
                # Input position i predicts ids[i + 1], so the continuation is
                # predicted at positions [context_len - 1, len(ids) - 1).
                full_score_start = context_len - 1
                full_score_end = len(ids) - 1
                score_start = max(full_score_start, window_start) - window_start
                score_end = full_score_end - window_start
                if score_end <= score_start:
                    raise ValueError("No continuation tokens remain after truncation")
                # Only the continuation tokens that survived truncation are scored.
                scored_continuation_ids = continuation_ids[-(score_end - score_start) :]
                rows.append(input_ids)
                row_places.append(window_places[:-1])
                row_roles.append(window_roles[:-1])
                row_targets.append(targets)
                row_score_slices.append((score_start, score_end, scored_continuation_ids))

            # Right-pad every row to the longest one; padded keys are hidden
            # from attention through attention_mask.
            max_len = max(len(row) for row in rows)
            input_tensor = torch.full(
                (len(rows), max_len),
                self.eot_token_id,
                dtype=torch.long,
                device=self.device,
            )
            place_tensor = torch.zeros_like(input_tensor)
            role_tensor = torch.zeros_like(input_tensor)
            attention_mask = torch.zeros_like(input_tensor, dtype=torch.bool)
            target_tensor = torch.full_like(input_tensor, self.eot_token_id)
            for row, (ids, places, roles, targets) in enumerate(
                zip(rows, row_places, row_roles, row_targets, strict=True)
            ):
                length = len(ids)
                input_tensor[row, :length] = torch.tensor(ids, device=self.device)
                place_tensor[row, :length] = torch.tensor(places, device=self.device)
                role_tensor[row, :length] = torch.tensor(roles, device=self.device)
                target_tensor[row, :length] = torch.tensor(targets, device=self.device)
                attention_mask[row, :length] = True

            with torch.inference_mode(), self._autocast_context():
                logits = self.model(
                    input_ids=input_tensor,
                    place_ids=place_tensor,
                    role_ids=role_tensor,
                    attention_mask=attention_mask,
                ).logits
            # Normalise in float32 regardless of the autocast dtype.
            log_probs = F.log_softmax(logits.float(), dim=-1)

            for row, (score_start, score_end, continuation_ids) in enumerate(row_score_slices):
                row_log_probs = log_probs[row, score_start:score_end]
                row_targets_for_score = target_tensor[row, score_start:score_end]
                token_log_probs = torch.gather(
                    row_log_probs,
                    1,
                    row_targets_for_score.unsqueeze(-1),
                ).squeeze(-1)
                greedy = torch.equal(
                    row_log_probs.argmax(dim=-1),
                    torch.tensor(continuation_ids, dtype=torch.long, device=self.device),
                )
                results.append((float(token_log_probs.sum().item()), bool(greedy)))

        return results

    def loglikelihood_rolling(
        self,
        requests: list["Instance"],
        disable_tqdm: bool = False,
    ) -> list[float]:
        """Return the total log-likelihood of each request's full text.

        The text is scored in consecutive chunks of at most ``max_length``
        tokens. The first chunk is conditioned on the EOT token; each later
        chunk is conditioned only on the single token that precedes it.
        """
        results = []
        for (text,) in tqdm(
            [req.args for req in requests],
            desc="Running fusion rolling loglikelihood",
            disable=disable_tqdm or self.rank != 0,
        ):
            encoding = self.tokenizer.encode(text)
            ids = encoding.input_ids
            places = encoding.place_ids
            roles = encoding.role_ids
            total = 0.0
            start = 0
            while start < len(ids):
                end = min(len(ids), start + self.max_length)
                # One token of left context: EOT for the first chunk, otherwise
                # the last token of the previous chunk.
                prefix = [self.eot_token_id] if start == 0 else ids[start - 1 : start]
                chunk_ids = prefix + ids[start:end]
                chunk_places = [0] + places[start:end] if start == 0 else places[start - 1 : end]
                chunk_roles = [0] + roles[start:end] if start == 0 else roles[start - 1 : end]
                input_ids = torch.tensor([chunk_ids[:-1]], dtype=torch.long, device=self.device)
                place_ids = torch.tensor([chunk_places[:-1]], dtype=torch.long, device=self.device)
                role_ids = torch.tensor([chunk_roles[:-1]], dtype=torch.long, device=self.device)
                targets = torch.tensor(chunk_ids[1:], dtype=torch.long, device=self.device)
                with torch.inference_mode():
                    logits = self.model(
                        input_ids=input_ids,
                        place_ids=place_ids,
                        role_ids=role_ids,
                    ).logits[0]
                log_probs = F.log_softmax(logits.float(), dim=-1)
                total += float(
                    torch.gather(log_probs, 1, targets.unsqueeze(-1)).sum().item()
                )
                start = end
            results.append(total)
        return results

    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
        """Unsupported: this adapter implements log-likelihood scoring only."""
        raise NotImplementedError(
            "FusionGPTLM currently supports loglikelihood tasks. "
            "Use tasks with multiple-choice/loglikelihood output."
        )
|
model.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
from transformers import PreTrainedModel
|
| 9 |
+
from transformers.cache_utils import DynamicCache
|
| 10 |
+
from transformers.generation.utils import GenerationMixin
|
| 11 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 12 |
+
|
| 13 |
+
from .config import GPTConfig
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
CONTROL_TENSOR_NAME_PATTERNS = (
|
| 17 |
+
"scale",
|
| 18 |
+
"gate",
|
| 19 |
+
"gain",
|
| 20 |
+
"norm",
|
| 21 |
+
"ln_",
|
| 22 |
+
"rms",
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class CastedLinear(nn.Linear):
    """Linear layer whose parameters are cast to the input dtype per call.

    The stored weight (and bias, if any) are left untouched; only the copies
    used for the matmul take the activation dtype.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Apply the affine map using parameters cast to ``x.dtype``."""
        compute_dtype = x.dtype
        cast_bias = None if self.bias is None else self.bias.to(dtype=compute_dtype)
        cast_weight = self.weight.to(dtype=compute_dtype)
        return F.linear(x, cast_weight, cast_bias)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def restore_fp32_params(model: nn.Module) -> None:
    """Return selected parameters of ``model`` to float32, in place.

    Two groups are converted: every ``CastedLinear`` module as a whole, and
    any parameter that has fewer than two dimensions or whose name contains
    one of ``CONTROL_TENSOR_NAME_PATTERNS``.
    """
    casted_layers = (m for m in model.modules() if isinstance(m, CastedLinear))
    for layer in casted_layers:
        layer.float()

    for param_name, tensor in model.named_parameters():
        is_control = tensor.ndim < 2 or any(
            pattern in param_name for pattern in CONTROL_TENSOR_NAME_PATTERNS
        )
        if not is_control or tensor.dtype == torch.float32:
            continue
        tensor.data = tensor.data.float()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim, eps=1e-6):
        """Create a gain vector of ``dim`` ones; ``eps`` is added under the root."""
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise ``x`` in float32, then return it scaled in ``x``'s own dtype."""
        out_dtype = x.dtype
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        normalized = (x32 * inv_rms).to(dtype=out_dtype)
        return normalized * self.weight.to(dtype=out_dtype)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def build_rope_inv_freq(head_dim, theta=2500.0):
    """Return the float32 inverse frequencies ``theta ** (-2i / head_dim)``.

    One value is produced for each even index ``2i`` below ``head_dim``.
    """
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    return 1.0 / (theta ** exponents)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def precompute_rope_cos_sin(head_dim, seq_len, theta=2500.0):
    """Return ``(cos, sin)`` rotary tables for positions ``0 .. seq_len - 1``.

    Each table is the outer product of the positions with the inverse
    frequencies, passed through cosine and sine respectively.
    """
    inv_freq = build_rope_inv_freq(head_dim, theta)
    positions = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    return angles.cos(), angles.sin()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _apply_rope(x, cos, sin):
    """Rotate adjacent (even, odd) feature pairs of ``x`` by the given angles.

    The rotation is computed in float32 and the result is cast back to the
    dtype of ``x``. ``cos`` and ``sin`` receive two leading singleton
    dimensions so they broadcast against ``x``.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    even, odd = pairs[..., 0], pairs[..., 1]
    cos_b = cos.unsqueeze(0).unsqueeze(0)
    sin_b = sin.unsqueeze(0).unsqueeze(0)
    rotated_even = even * cos_b - odd * sin_b
    rotated_odd = even * sin_b + odd * cos_b
    # Interleave the two halves back into the original pair layout.
    rotated = torch.stack((rotated_even, rotated_odd), dim=-1)
    return rotated.flatten(-2).type_as(x)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def apply_rotary_emb(q, k, freqs_cis):
    """Apply the same rotary rotation to queries and keys.

    ``freqs_cis`` is a ``(cos, sin)`` pair; the rotated ``(q, k)`` is returned.
    """
    cos_table, sin_table = freqs_cis
    rotated_q = _apply_rope(q, cos_table, sin_table)
    rotated_k = _apply_rope(k, cos_table, sin_table)
    return rotated_q, rotated_k
|
| 84 |
+
|
| 85 |
+
class GPTAttention(nn.Module):
    """Causal self-attention with rotary embeddings and an optional KV cache.

    Keys and values use ``num_key_value_heads`` heads and queries use
    ``num_attention_heads``. When ``xsa_projection`` is enabled, the component
    of each head's output along the current token's normalised value vector
    is removed before the output projection.
    """

    def __init__(self, config, layer_idx):
        """Read head geometry from ``config`` and create the four projections."""
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.num_attention_heads
        self.n_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.n_rep = self.n_head // self.n_kv_heads
        self.xsa_projection = config.xsa_projection

        self.q_proj = CastedLinear(config.hidden_size, self.n_head * self.head_dim, bias=False)
        self.k_proj = CastedLinear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = CastedLinear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = CastedLinear(self.n_head * self.head_dim, config.hidden_size, bias=False)

    def _xsa_efficient(self, y: Tensor, v_current: Tensor) -> Tensor:
        """Remove from ``y`` its projection onto the normalised current values.

        ``y`` is (B, H, T, D) and ``v_current`` is (B, Hkv, T, D); the query
        heads are grouped so each group shares one value head.
        """
        B, H, T, D = y.shape
        Hkv = v_current.size(1)
        group = H // Hkv

        y_g = y.reshape(B, Hkv, group, T, D)
        v_n = F.normalize(v_current, dim=-1).unsqueeze(2)

        proj = (y_g * v_n).sum(dim=-1, keepdim=True) * v_n
        return (y_g - proj).reshape(B, H, T, D)

    def forward(self, x, freqs_cis, past_key_value=None, use_cache=False, attention_mask=None):
        """Attend over the current tokens plus any cached keys and values.

        Args:
            x: hidden states of shape (B, T, hidden_size).
            freqs_cis: ``(cos, sin)`` rotary tables for the T current positions.
            past_key_value: optional cache exposing ``update`` and
                ``get_seq_length``; it is updated with this layer's keys/values.
            use_cache: accepted for interface symmetry; not read here.
            attention_mask: optional key-padding mask broadcast over heads and
                queries; truthy entries mark keys that may be attended to.

        Returns:
            Tensor of shape (B, T, hidden_size).
        """
        B, T, _ = x.size()

        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k_current = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v_current = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        q, k_current = apply_rotary_emb(q, k_current, freqs_cis)

        if past_key_value is not None:
            k, v = past_key_value.update(k_current, v_current, self.layer_idx)
        else:
            k, v = k_current, v_current

        S = k.size(2)

        # Prefill: the cache (if any) holds exactly the T current tokens.
        is_causal = past_key_value is None or past_key_value.get_seq_length(self.layer_idx) == T
        # Several new tokens on top of an existing cache. These still need a
        # causal mask among themselves; without one, each new token would
        # attend to later tokens of the same chunk.
        chunked_decode = (not is_causal) and T > 1

        attn_mask = None
        if attention_mask is not None or chunked_decode:
            if T > 1:
                # Bottom-right aligned causal mask: query t may see keys up to
                # position (S - T) + t.
                causal = torch.ones(T, S, dtype=torch.bool, device=x.device).tril(diagonal=S - T)
                attn_mask = causal[None, None, :, :]
                if attention_mask is not None:
                    key_pad = attention_mask.to(torch.bool)[:, None, None, :]
                    attn_mask = key_pad & attn_mask
            else:
                key_pad = attention_mask.to(torch.bool)[:, None, None, :]
                attn_mask = key_pad.expand(B, 1, T, S)

            # An explicit mask and is_causal=True cannot be combined.
            is_causal = False

        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            is_causal=is_causal,
            enable_gqa=(self.n_kv_heads != self.n_head),
        )

        if self.xsa_projection:
            y = self._xsa_efficient(y, v_current)

        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
        return self.o_proj(y)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class GPTMLP(nn.Module):
    """Gated feed-forward block: ``w_down(silu(w_gate(x)) * w_up(x))``."""

    def __init__(self, config):
        """Create the bias-free gate, up and down projections."""
        super().__init__()
        model_width = config.hidden_size
        inner_width = config.intermediate_size
        self.w_gate = CastedLinear(model_width, inner_width, bias=False)
        self.w_up = CastedLinear(model_width, inner_width, bias=False)
        self.w_down = CastedLinear(inner_width, model_width, bias=False)

    def forward(self, x):
        """Return the gated projection of ``x`` at the model width."""
        gated = F.silu(self.w_gate(x)) * self.w_up(x)
        return self.w_down(gated)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class GPTBlock(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual add."""

    def __init__(self, config, layer_idx):
        """Build the two norms, the attention layer and the MLP for this layer."""
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = GPTAttention(config, layer_idx)
        self.ln_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = GPTMLP(config)

    def forward(self, x, freqs_cis, past_key_value=None, use_cache=False, attention_mask=None):
        """Run attention and the MLP on normalised inputs, adding each to ``x``."""
        attn_out = self.attn(
            self.ln_1(x),
            freqs_cis,
            past_key_value,
            use_cache,
            attention_mask=attention_mask,
        )
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class GPTPreTrainedModel(PreTrainedModel):
    """Shared base class wiring ``GPTConfig`` and weight initialisation."""

    config_class = GPTConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Draw linear and embedding weights from N(0, hidden_size ** -0.5).

        Other module types are left untouched.
        """
        std = self.config.hidden_size ** -0.5
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class GPTForCausalLM(GPTPreTrainedModel, GenerationMixin):
|
| 196 |
+
_tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
|
| 197 |
+
|
| 198 |
+
def __init__(self, config):
    """Assemble the token table, blocks, final norm, LM head and feature tables.

    The LM head shares the token-embedding weight when
    ``config.tie_word_embeddings`` is set. The place/role embedding tables are
    created unless disabled in ``config``; their index-0 rows are zeroed after
    weight initialisation. Finally the float32 parameter policy is applied.
    """
    super().__init__(config)
    self.config = config
    width = config.hidden_size

    token_table = nn.Embedding(config.vocab_size, width)
    blocks = nn.ModuleList([GPTBlock(config, idx) for idx in range(config.num_hidden_layers)])
    final_norm = RMSNorm(width, eps=config.rms_norm_eps)
    self.transformer = nn.ModuleDict({"wte": token_table, "h": blocks, "ln_f": final_norm})

    self.lm_head = CastedLinear(width, config.vocab_size, bias=False)
    if config.tie_word_embeddings:
        self.lm_head.weight = self.transformer["wte"].weight

    self.place_embeddings = (
        nn.Embedding(config.place_vocab_size, width, padding_idx=0)
        if getattr(config, "use_place_embeddings", True)
        else None
    )
    self.role_embeddings = (
        nn.Embedding(config.role_vocab_size, width, padding_idx=0)
        if getattr(config, "use_role_embeddings", True)
        else None
    )

    self._freqs_cis_cache = None
    self.post_init()

    # post_init re-initialises embedding weights, so the index-0 rows are
    # zeroed again here.
    with torch.no_grad():
        for table in (self.place_embeddings, self.role_embeddings):
            if table is not None:
                table.weight[0].zero_()
    restore_fp32_params(self)
|
| 233 |
+
|
| 234 |
+
def _apply(self, fn):
|
| 235 |
+
module = super()._apply(fn)
|
| 236 |
+
restore_fp32_params(self)
|
| 237 |
+
return module
|
| 238 |
+
|
| 239 |
+
def get_input_embeddings(self):
|
| 240 |
+
return self.transformer["wte"]
|
| 241 |
+
|
| 242 |
+
def set_input_embeddings(self, value):
|
| 243 |
+
self.transformer["wte"] = value
|
| 244 |
+
|
| 245 |
+
def get_output_embeddings(self):
|
| 246 |
+
return self.lm_head
|
| 247 |
+
|
| 248 |
+
def set_output_embeddings(self, new_embeddings):
|
| 249 |
+
self.lm_head = new_embeddings
|
| 250 |
+
|
| 251 |
+
def embed_tokens(self, input_ids, *, place_ids=None, role_ids=None, **kwargs):
|
| 252 |
+
embeddings = self.transformer["wte"](input_ids)
|
| 253 |
+
if self.place_embeddings is not None:
|
| 254 |
+
if place_ids is None:
|
| 255 |
+
place_ids = torch.zeros_like(input_ids)
|
| 256 |
+
if place_ids.shape != input_ids.shape:
|
| 257 |
+
raise ValueError("place_ids must match input_ids shape")
|
| 258 |
+
embeddings = embeddings + self.place_embeddings(place_ids)
|
| 259 |
+
if self.role_embeddings is not None:
|
| 260 |
+
if role_ids is None:
|
| 261 |
+
role_ids = torch.zeros_like(input_ids)
|
| 262 |
+
if role_ids.shape != input_ids.shape:
|
| 263 |
+
raise ValueError("role_ids must match input_ids shape")
|
| 264 |
+
embeddings = embeddings + self.role_embeddings(role_ids)
|
| 265 |
+
return embeddings
|
| 266 |
+
|
| 267 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
|
| 268 |
+
if past_key_values is not None and past_key_values.get_seq_length() > 0:
|
| 269 |
+
input_ids = input_ids[:, -1:]
|
| 270 |
+
if kwargs.get("place_ids") is not None:
|
| 271 |
+
kwargs["place_ids"] = kwargs["place_ids"][:, -1:]
|
| 272 |
+
if kwargs.get("role_ids") is not None:
|
| 273 |
+
kwargs["role_ids"] = kwargs["role_ids"][:, -1:]
|
| 274 |
+
return {
|
| 275 |
+
"input_ids": input_ids,
|
| 276 |
+
"place_ids": kwargs.get("place_ids"),
|
| 277 |
+
"role_ids": kwargs.get("role_ids"),
|
| 278 |
+
"attention_mask": attention_mask,
|
| 279 |
+
"past_key_values": past_key_values,
|
| 280 |
+
"use_cache": True,
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
def _get_freqs_cis(self, seq_len, device):
|
| 284 |
+
cache = self._freqs_cis_cache
|
| 285 |
+
if cache is None or cache[0].device != device or cache[0].size(0) < seq_len:
|
| 286 |
+
cache = tuple(
|
| 287 |
+
tensor.to(device)
|
| 288 |
+
for tensor in precompute_rope_cos_sin(self.config.head_dim, seq_len, self.config.rope_theta)
|
| 289 |
+
)
|
| 290 |
+
if torch.is_inference_mode_enabled():
|
| 291 |
+
return cache[0][:seq_len], cache[1][:seq_len]
|
| 292 |
+
self._freqs_cis_cache = cache
|
| 293 |
+
return cache[0][:seq_len], cache[1][:seq_len]
|
| 294 |
+
|
| 295 |
+
def forward(
|
| 296 |
+
self,
|
| 297 |
+
input_ids,
|
| 298 |
+
attention_mask=None,
|
| 299 |
+
labels=None,
|
| 300 |
+
past_key_values: Optional[DynamicCache] = None,
|
| 301 |
+
use_cache=False,
|
| 302 |
+
**kwargs,
|
| 303 |
+
):
|
| 304 |
+
B, T = input_ids.size()
|
| 305 |
+
if use_cache and past_key_values is None:
|
| 306 |
+
past_key_values = DynamicCache()
|
| 307 |
+
|
| 308 |
+
past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 309 |
+
x = self.embed_tokens(input_ids, **kwargs)
|
| 310 |
+
cos, sin = self._get_freqs_cis(past_len + T, input_ids.device)
|
| 311 |
+
freqs_cis = (
|
| 312 |
+
cos[past_len:past_len + T],
|
| 313 |
+
sin[past_len:past_len + T],
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
for block in self.transformer["h"]:
|
| 317 |
+
x = block(x, freqs_cis, past_key_values if use_cache else None, use_cache, attention_mask=attention_mask)
|
| 318 |
+
|
| 319 |
+
x = self.transformer["ln_f"](x)
|
| 320 |
+
logits = self.lm_head(x)
|
| 321 |
+
|
| 322 |
+
loss = None
|
| 323 |
+
if labels is not None:
|
| 324 |
+
if getattr(self.config, "labels_are_shifted", False):
|
| 325 |
+
loss = F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), labels.reshape(-1))
|
| 326 |
+
else:
|
| 327 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 328 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 329 |
+
loss = F.cross_entropy(shift_logits.float().view(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
|
| 330 |
+
|
| 331 |
+
return CausalLMOutputWithPast(
|
| 332 |
+
loss=loss,
|
| 333 |
+
logits=logits,
|
| 334 |
+
past_key_values=past_key_values if use_cache else None,
|
| 335 |
+
)
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a21f5910c898c5464320d3f01b15ccde1eb278073266b221e48e9ec15ccbe899
|
| 3 |
+
size 10930496
|
pretraining_curriculum.json
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": 1,
|
| 3 |
+
"transition_fraction": 0.1,
|
| 4 |
+
"stages": [
|
| 5 |
+
{
|
| 6 |
+
"name": "early",
|
| 7 |
+
"start": 0.0,
|
| 8 |
+
"end": 0.4,
|
| 9 |
+
"weights": {
|
| 10 |
+
"Ultra-FineWeb": 0.5,
|
| 11 |
+
"FineWeb-Edu": 0.38,
|
| 12 |
+
"FineMath": 0.05,
|
| 13 |
+
"Cosmopedia-v2": 0.05,
|
| 14 |
+
"UltraData-Math-L2-preview": 0.02
|
| 15 |
+
}
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"name": "mid",
|
| 19 |
+
"start": 0.4,
|
| 20 |
+
"end": 0.8,
|
| 21 |
+
"weights": {
|
| 22 |
+
"Ultra-FineWeb": 0.12,
|
| 23 |
+
"FineWeb-Edu": 0.22,
|
| 24 |
+
"FineMath": 0.18,
|
| 25 |
+
"Cosmopedia-v2": 0.13,
|
| 26 |
+
"UltraData-Math-L2-preview": 0.12,
|
| 27 |
+
"Ultra-FineWeb-L3-en-QA-Synthetic": 0.05,
|
| 28 |
+
"Synthetic-Arithmetic": 0.18
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"name": "late",
|
| 33 |
+
"start": 0.8,
|
| 34 |
+
"end": 1.0,
|
| 35 |
+
"weights": {
|
| 36 |
+
"Ultra-FineWeb": 0.105,
|
| 37 |
+
"FineWeb-Edu": 0.21,
|
| 38 |
+
"FineMath": 0.14,
|
| 39 |
+
"Cosmopedia-v2": 0.14,
|
| 40 |
+
"UltraData-Math-L2-preview": 0.105,
|
| 41 |
+
"Ultra-FineWeb-L3-en-QA-Synthetic": 0.2,
|
| 42 |
+
"Synthetic-Arithmetic": 0.1
|
| 43 |
+
}
|
| 44 |
+
}
|
| 45 |
+
]
|
| 46 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformers
|
| 3 |
+
tokenizers
|
| 4 |
+
safetensors
|
| 5 |
+
tqdm
|
| 6 |
+
lm-eval
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<|endoftext|>"
|
| 4 |
+
],
|
| 5 |
+
"bos_token": "<|bos|>",
|
| 6 |
+
"eos_token": "<|eos|>",
|
| 7 |
+
"pad_token": "<|pad|>",
|
| 8 |
+
"unk_token": "<|unk|>"
|
| 9 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<|endoftext|>"
|
| 4 |
+
],
|
| 5 |
+
"bos_token": "<|bos|>",
|
| 6 |
+
"eos_token": "<|eos|>",
|
| 7 |
+
"model_max_length": 512,
|
| 8 |
+
"pad_token": "<|pad|>",
|
| 9 |
+
"tokenizer_class": "GPT2TokenizerFast",
|
| 10 |
+
"unk_token": "<|unk|>"
|
| 11 |
+
}
|
tokenizer_utils.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared construction and loading helpers for the project's tokenizer."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import re
|
| 9 |
+
from typing import Any, Iterable
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Reserved tokens handed to the BPE trainer. validate_tokenizer() requires
# "<|endoftext|>" to have a vocabulary id equal to its index in this list.
SPECIAL_TOKENS = ["<|pad|>", "<|bos|>", "<|eos|>", "<|unk|>", "<|endoftext|>"]
EOT_ID = SPECIAL_TOKENS.index("<|endoftext|>")

# Single-character operator/bracket tokens that must exist in the vocabulary.
ARITHMETIC_TOKENS = tuple("+-*/=()")

# Place ids: 0 = not a digit, 1..MAX_PLACE_ID = position inside a digit run,
# PLACE_OVERFLOW_ID = any digit beyond MAX_PLACE_ID.
MAX_PLACE_ID = 64
PLACE_OVERFLOW_ID = MAX_PLACE_ID + 1
PLACE_VOCAB_SIZE = PLACE_OVERFLOW_ID + 1

# Role ids: 0 = none, 1..MAX_OPERAND_ROLES = operand number, then the two
# dedicated ids below for result digits and in-equation whitespace.
RESULT_ROLE_ID = 10
SPACE_ROLE_ID = 11
ROLE_VOCAB_SIZE = SPACE_ROLE_ID + 1
MAX_OPERAND_ROLES = 9
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass(frozen=True)
class FusionEncoding:
    """Immutable result of ``FusionTokenizer.encode``.

    Holds three parallel per-token streams (``ids``, ``place_ids``,
    ``role_ids``) plus the optional token strings. Construction fails when the
    three id streams differ in length. ``len()`` and iteration operate on ``ids``.
    """

    ids: list[int]
    place_ids: list[int]
    role_ids: list[int]
    tokens: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # A set of the three lengths collapses to one element only when all agree.
        stream_lengths = {len(self.ids), len(self.place_ids), len(self.role_ids)}
        if len(stream_lengths) > 1:
            raise ValueError("Fusion tokenizer streams must have equal length")

    @property
    def input_ids(self) -> list[int]:
        """Alias for ``ids``."""
        return self.ids

    def __len__(self) -> int:
        return len(self.ids)

    def __iter__(self):
        return iter(self.ids)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def build_tokenizer() -> Any:
    """Build a byte-level BPE tokenizer with explicit lossless boundaries.

    The pre-tokenizer first isolates whitespace runs, single digits, single
    arithmetic symbols and runs of everything else, then applies byte-level
    encoding without a prefix space and without its own regex split.
    """
    from tokenizers import Regex, Tokenizer, decoders, models, pre_tokenizers

    boundary_split = pre_tokenizers.Split(
        Regex(r"\s+|\d|[+\-*/=()]|[^\s\d+\-*/=()]+"),
        behavior="isolated",
    )
    byte_level = pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)

    bpe_tokenizer = Tokenizer(models.BPE(unk_token="<|unk|>"))
    bpe_tokenizer.pre_tokenizer = pre_tokenizers.Sequence([boundary_split, byte_level])
    bpe_tokenizer.decoder = decoders.ByteLevel()
    return bpe_tokenizer
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class FusionTokenizer:
    """Runtime wrapper adding LSD-first digit streams to a trained BPE tokenizer.

    ``encode`` reverses every run of ASCII digits before tokenizing, so the
    least-significant digit comes first, and attaches parallel ``place_ids``
    and ``role_ids`` streams. ``decode`` reverses runs of digit tokens back.
    Attributes not defined on this class are forwarded to the wrapped tokenizer.
    """

    # ASCII digits only. ``\d`` would also match other Unicode decimal digits,
    # which have no atomic digit token and are therefore never un-reversed by
    # ``decode``; reversing them here would break the encode/decode round trip.
    _digit_span_re = re.compile(r"[0-9]+")

    def __init__(self, tokenizer: Any):
        """Wrap ``tokenizer`` and cache the ids of digit, '=' and special tokens.

        Raises:
            ValueError: If the vocabulary lacks any of the ten digit tokens or '='.
        """
        self.tokenizer = tokenizer
        self._digit_token_ids = frozenset(
            token_id
            for digit in "0123456789"
            if (token_id := self.tokenizer.token_to_id(digit)) is not None
        )
        self._digit_id_to_text = {
            int(token_id): digit
            for digit in "0123456789"
            if (token_id := self.tokenizer.token_to_id(digit)) is not None
        }
        self._equals_id = self.tokenizer.token_to_id("=")
        self._special_token_ids = frozenset(
            token_id
            for token in SPECIAL_TOKENS
            if (token_id := self.tokenizer.token_to_id(token)) is not None
        )
        if len(self._digit_token_ids) != 10:
            raise ValueError("Tokenizer vocabulary must contain atomic digit tokens 0-9")
        if self._equals_id is None:
            raise ValueError("Tokenizer vocabulary must contain an atomic '=' token")

    def __getattr__(self, name: str) -> Any:
        """Forward unknown attributes to the wrapped tokenizer."""
        # ``__getattr__`` only runs when normal lookup fails. If ``tokenizer``
        # itself is missing (instance created via ``__new__``, copy or
        # unpickling), looking it up below would re-enter this method forever.
        if name == "tokenizer":
            raise AttributeError(name)
        return getattr(self.tokenizer, name)

    @property
    def digit_token_ids(self) -> frozenset[int]:
        """Ids of the ten atomic digit tokens."""
        return self._digit_token_ids

    @property
    def special_token_ids(self) -> frozenset[int]:
        """Ids of the ``SPECIAL_TOKENS`` present in the vocabulary."""
        return self._special_token_ids

    def get_vocab_size(self, with_added_tokens: bool = True) -> int:
        """Return the wrapped tokenizer's vocabulary size as an ``int``."""
        return int(self.tokenizer.get_vocab_size(with_added_tokens=with_added_tokens))

    def get_vocab(self, with_added_tokens: bool = True) -> dict[str, int]:
        """Return the wrapped tokenizer's token-to-id mapping."""
        return self.tokenizer.get_vocab(with_added_tokens=with_added_tokens)

    def token_to_id(self, token: str) -> int | None:
        """Return the id of ``token`` as reported by the wrapped tokenizer."""
        return self.tokenizer.token_to_id(token)

    def id_to_token(self, token_id: int) -> str | None:
        """Return the token string for ``token_id`` as reported by the wrapped tokenizer."""
        return self.tokenizer.id_to_token(int(token_id))

    @classmethod
    def _reverse_digit_spans(cls, text: str) -> str:
        """Return ``text`` with every run of ASCII digits reversed in place."""
        return cls._digit_span_re.sub(lambda match: match.group(0)[::-1], text)

    def _decode_token_piece(self, token_id: int) -> str:
        """Decode a single token id to text, keeping special tokens."""
        return self.tokenizer.decode([int(token_id)], skip_special_tokens=False)

    @staticmethod
    def _is_equation_whitespace(piece: str) -> bool:
        """Return True for non-empty whitespace that contains no line break."""
        return bool(piece) and piece.isspace() and "\n" not in piece and "\r" not in piece

    def _is_equation_piece(self, token_id: int, piece: str) -> bool:
        """Return True if the token may be part of an equation span.

        Digits, same-line whitespace and single arithmetic symbols qualify;
        special tokens never do.
        """
        if token_id in self._special_token_ids:
            return False
        if token_id in self._digit_token_ids:
            return True
        if self._is_equation_whitespace(piece):
            return True
        return len(piece) == 1 and piece in ARITHMETIC_TOKENS

    def _digit_runs(self, ids: list[int], start: int, end: int) -> list[tuple[int, int]]:
        """Return ``(run_start, run_end)`` for each maximal digit run in ``ids[start:end]``."""
        runs: list[tuple[int, int]] = []
        index = start
        while index < end:
            if ids[index] not in self._digit_token_ids:
                index += 1
                continue
            run_start = index
            while index < end and ids[index] in self._digit_token_ids:
                index += 1
            runs.append((run_start, index))
        return runs

    def _annotate_equation_span(
        self,
        ids: list[int],
        pieces: list[str],
        start: int,
        end: int,
        role_ids: list[int],
    ) -> None:
        """Write role ids for the equation candidate ``ids[start:end]`` into ``role_ids``.

        The span is left untouched unless it contains exactly one '=' token, at
        least one digit run on each side of it, and at most
        ``MAX_OPERAND_ROLES`` digit runs on the left. Otherwise whitespace gets
        ``SPACE_ROLE_ID``, the n-th left-hand digit run gets role ``n`` (from 1)
        and every right-hand digit gets ``RESULT_ROLE_ID``.
        """
        equals_positions = [
            index
            for index in range(start, end)
            if ids[index] == self._equals_id
        ]
        if len(equals_positions) != 1:
            return
        equals_position = equals_positions[0]

        digit_runs = self._digit_runs(ids, start, end)
        operand_runs = [(a, b) for a, b in digit_runs if b <= equals_position]
        result_runs = [(a, b) for a, b in digit_runs if a > equals_position]
        if not operand_runs or not result_runs or len(operand_runs) > MAX_OPERAND_ROLES:
            return

        for index in range(start, end):
            if self._is_equation_whitespace(pieces[index]):
                role_ids[index] = SPACE_ROLE_ID

        for role, (run_start, run_end) in enumerate(operand_runs, start=1):
            for index in range(run_start, run_end):
                role_ids[index] = role
        for run_start, run_end in result_runs:
            for index in range(run_start, run_end):
                role_ids[index] = RESULT_ROLE_ID

    def annotate_ids(self, ids: Iterable[int]) -> tuple[list[int], list[int]]:
        """Compute the ``(place_ids, role_ids)`` streams for a token id sequence.

        Place ids number the digits of each digit run from 1, capped at
        ``PLACE_OVERFLOW_ID``; non-digits get 0. Role ids are filled per maximal
        span of equation pieces by ``_annotate_equation_span``; everything else
        stays 0.
        """
        input_ids = [int(token_id) for token_id in ids]
        place_ids = [0] * len(input_ids)
        role_ids = [0] * len(input_ids)
        pieces = [self._decode_token_piece(token_id) for token_id in input_ids]

        for run_start, run_end in self._digit_runs(input_ids, 0, len(input_ids)):
            for place, index in enumerate(range(run_start, run_end), start=1):
                place_ids[index] = min(place, PLACE_OVERFLOW_ID)

        span_start: int | None = None
        for index, (token_id, piece) in enumerate(zip(input_ids, pieces, strict=True)):
            if self._is_equation_piece(token_id, piece):
                if span_start is None:
                    span_start = index
                continue
            if span_start is not None:
                self._annotate_equation_span(input_ids, pieces, span_start, index, role_ids)
                span_start = None
        # A span that runs to the end of the sequence is closed here.
        if span_start is not None:
            self._annotate_equation_span(input_ids, pieces, span_start, len(input_ids), role_ids)

        return place_ids, role_ids

    def encode(self, text: str, *args, **kwargs) -> FusionEncoding:
        """Reverse digit runs in ``text``, tokenize it and attach place/role ids.

        Extra arguments are passed through to the wrapped tokenizer's ``encode``.
        """
        transformed = self._reverse_digit_spans(text)
        encoding = self.tokenizer.encode(transformed, *args, **kwargs)
        ids = [int(token_id) for token_id in encoding.ids]
        place_ids, role_ids = self.annotate_ids(ids)
        return FusionEncoding(
            ids=ids,
            place_ids=place_ids,
            role_ids=role_ids,
            tokens=list(getattr(encoding, "tokens", [])),
        )

    def encode_batch(self, texts: list[str], *args, **kwargs) -> list[FusionEncoding]:
        """Encode each text with ``encode`` and return the results in order."""
        return [self.encode(text, *args, **kwargs) for text in texts]

    def decode(
        self,
        token_ids: Iterable[int],
        skip_special_tokens: bool = True,
    ) -> str:
        """Decode token ids to text, restoring the original digit order.

        Consecutive digit tokens are buffered and emitted reversed; the
        non-digit tokens between them are decoded in one call to the wrapped
        tokenizer so multi-token byte sequences stay intact.
        """
        pieces: list[str] = []
        text_ids: list[int] = []
        digit_buffer: list[str] = []

        def flush_text() -> None:
            if text_ids:
                pieces.append(
                    self.tokenizer.decode(
                        text_ids,
                        skip_special_tokens=skip_special_tokens,
                    )
                )
                text_ids.clear()

        def flush_digits() -> None:
            if digit_buffer:
                pieces.extend(reversed(digit_buffer))
                digit_buffer.clear()

        for raw_id in token_ids:
            token_id = int(raw_id)
            if token_id in self._digit_token_ids:
                flush_text()
                digit_buffer.append(self._digit_id_to_text[token_id])
                continue

            flush_digits()
            text_ids.append(token_id)

        flush_text()
        flush_digits()
        return "".join(pieces)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def build_trainer(vocab_size: int, min_frequency: int) -> Any:
    """Return a ``BpeTrainer`` seeded with the special tokens and the byte-level alphabet.

    Args:
        vocab_size: Target vocabulary size passed to the trainer.
        min_frequency: Minimum pair frequency passed to the trainer.
    """
    from tokenizers import pre_tokenizers, trainers

    trainer_options = {
        "vocab_size": vocab_size,
        "min_frequency": min_frequency,
        "special_tokens": SPECIAL_TOKENS,
        "initial_alphabet": pre_tokenizers.ByteLevel.alphabet(),
    }
    return trainers.BpeTrainer(**trainer_options)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def tokenizer_files(tokenizer_dir: Path) -> tuple[Path, Path, Path]:
    """Return the ``tokenizer.json``, ``vocab.json`` and ``merges.txt`` paths inside ``tokenizer_dir``."""
    tokenizer_json = tokenizer_dir / "tokenizer.json"
    vocab_json = tokenizer_dir / "vocab.json"
    merges_txt = tokenizer_dir / "merges.txt"
    return tokenizer_json, vocab_json, merges_txt
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def validate_tokenizer(tokenizer_dir: Path) -> None:
    """Check that the tokenizer stored in ``tokenizer_dir`` meets the project's requirements.

    The vocabulary is read from ``vocab.json`` when that file exists, otherwise
    from the ``model.vocab`` entry of ``tokenizer.json``.

    Raises:
        FileNotFoundError: If ``tokenizer.json`` is missing, or ``vocab.json``
            is missing and ``tokenizer.json`` has no embedded vocab dict.
        ValueError: If the largest id exceeds 65535, ``<|endoftext|>`` does not
            have id ``EOT_ID``, or a digit / arithmetic token is absent.
    """
    tokenizer_json, vocab_path, _ = tokenizer_files(tokenizer_dir)
    if not tokenizer_json.exists():
        raise FileNotFoundError(
            f"Missing {tokenizer_json}. Retrain with train_tokenizer.py so the "
            "whitespace and digit boundary rules are preserved."
        )

    if vocab_path.exists():
        with vocab_path.open("r", encoding="utf-8") as handle:
            vocab = json.load(handle)
    else:
        with tokenizer_json.open("r", encoding="utf-8") as handle:
            vocab = json.load(handle).get("model", {}).get("vocab")
        if not isinstance(vocab, dict):
            raise FileNotFoundError(f"Missing vocab.json and no embedded vocab in {tokenizer_json}")

    # Ids must fit in an unsigned 16-bit integer.
    highest_id = max(vocab.values())
    if highest_id > 65_535:
        raise ValueError(f"Tokenizer max id {highest_id} does not fit in uint16")

    eot_id = vocab.get("<|endoftext|>")
    if eot_id != EOT_ID:
        raise ValueError(f"Expected <|endoftext|> id {EOT_ID}, got {eot_id}")

    required_tokens = [str(value) for value in range(10)] + list(ARITHMETIC_TOKENS)
    missing = [token for token in required_tokens if token not in vocab]
    if missing:
        raise ValueError(f"Tokenizer missing required atomic tokens: {missing}")
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def load_tokenizer(tokenizer_dir: Path) -> Any:
    """Validate ``tokenizer_dir`` and return its ``tokenizer.json`` wrapped in a ``FusionTokenizer``.

    Any exception raised by ``validate_tokenizer`` propagates unchanged.
    """
    from tokenizers import Tokenizer

    validate_tokenizer(tokenizer_dir)
    tokenizer_json = tokenizer_files(tokenizer_dir)[0]
    base_tokenizer = Tokenizer.from_file(str(tokenizer_json))
    return FusionTokenizer(base_tokenizer)
|