ucr-max committed on
Commit 271e253 · verified · 1 Parent(s): 0b2a1c2

Upload Atom2.7m model

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ bg.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,162 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ pipeline_tag: text-generation
6
+ tags:
7
+ - causal-lm
8
+ - gpt
9
+ - small-language-model
10
+ - arithmetic
11
+ - custom-tokenizer
12
+ - custom-code
13
+ - safetensors
14
+ - lm-evaluation-harness
15
+ datasets:
16
+ - openbmb/Ultra-FineWeb
17
+ - HuggingFaceFW/fineweb-edu
18
+ - HuggingFaceTB/finemath
19
+ - HuggingFaceTB/smollm-corpus
20
+ ---
21
+
22
+ ![bg](bg.png)
23
+ # Atom2.7m
24
+
25
+ Atom2.7m is a small decoder-only causal language model trained with a general byte-level BPE tokenizer plus arithmetic-specific digit features. The model has 2,738,880 parameters and uses custom code for both the model and the tokenizer path.
26
+
27
+ ## Model Details
28
+
29
+ - Architecture: decoder-only GPT
30
+ - Parameters: 2,738,880
31
+ - Layers: 5
32
+ - Hidden size: 192
33
+ - Attention heads: 4
34
+ - KV heads: 2
35
+ - Context length: 512
36
+ - Vocabulary size: 4,096
37
+ - Token embeddings: tied input/output embeddings
38
+ - Arithmetic feature embeddings:
39
+   - `place_vocab_size`: 66
40
+   - `role_vocab_size`: 12
41
+
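The parameter count can be reproduced from the shapes listed above, under assumptions about the layer layout that this README does not state: bias-free projections, a gated MLP with three weight matrices of width 480 (the `intermediate_size` in `config.json`), two RMSNorm weight vectors per layer plus one final norm, and tied token embeddings counted once. `count_parameters` is a hypothetical helper for illustration and is not part of the repository.

```python
def count_parameters(
    vocab_size=4096,
    hidden_size=192,
    num_layers=5,
    num_heads=4,
    num_kv_heads=2,
    intermediate_size=480,
    place_vocab_size=66,
    role_vocab_size=12,
):
    """Estimate the parameter count under the assumed layout described above."""
    head_dim = hidden_size // num_heads
    # Embedding tables: tokens (tied with the output head), places, roles.
    embeddings = (vocab_size + place_vocab_size + role_vocab_size) * hidden_size
    # Grouped-query attention: full-width Q and O, narrower K and V.
    q_and_o = 2 * hidden_size * (num_heads * head_dim)
    k_and_v = 2 * hidden_size * (num_kv_heads * head_dim)
    # Assumed gated MLP: gate, up, and down projections, no biases.
    mlp = 3 * hidden_size * intermediate_size
    # Assumed two RMSNorm weight vectors per layer.
    norms = 2 * hidden_size
    per_layer = q_and_o + k_and_v + mlp + norms
    # Assumed single final RMSNorm before the tied output head.
    return embeddings + num_layers * per_layer + hidden_size


print(count_parameters())
```

Under these assumptions the total comes to 2,738,880, which agrees with the figure above. The actual `model.py` is not shown here and may arrive at the same number with a different layout.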
42
+ ## Tokenizer
43
+
44
+ This model should not be evaluated or used with a plain Hugging Face tokenizer path alone. It uses a custom fusion tokenizer implemented in `tokenizer_utils.py`.
45
+
46
+ The tokenizer keeps byte-level BPE for ordinary text, but treats arithmetic-sensitive spans specially:
47
+
48
+ - digits `0`-`9` are atomic and never BPE-merged
49
+ - digit spans are emitted least-significant-digit first
50
+ - `+ - * / = ( )` are isolated atomic tokens
51
+ - whitespace is isolated from text
52
+ - `place_ids` are assigned to every digit run
53
+ - `role_ids` are assigned only for strict integer equation spans
54
+
55
+ The model expects aligned `input_ids`, `place_ids`, and `role_ids`.
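The digit rules above can be sketched in a few lines. This is a simplified illustration, not the code in `tokenizer_utils.py`: the function name `sketch_encode`, the place numbering (1 for the ones digit, 0 for non-digit tokens), and the handling of ordinary text (kept as whole spans instead of being BPE-encoded) are all assumptions, and `role_ids` are left out entirely.

```python
import re

# Digit runs, the atomic symbols listed above, single whitespace characters, or other text.
_SPAN = re.compile(r"[0-9]+|[+\-*/=()]|\s|[^0-9+\-*/=()\s]+")


def sketch_encode(text):
    """Return (tokens, place_ids) with digit runs emitted least-significant digit first."""
    tokens, place_ids = [], []
    for span in _SPAN.findall(text):
        if span.isascii() and span.isdigit():
            # Reverse the run; place 1 is the ones digit (assumed numbering).
            for place, digit in enumerate(reversed(span), start=1):
                tokens.append(digit)
                place_ids.append(place)
        else:
            # The real tokenizer would apply byte-level BPE to ordinary text here.
            tokens.append(span)
            place_ids.append(0)  # assumed id meaning "not a digit"
    return tokens, place_ids


tokens, place_ids = sketch_encode("12 + 34 =")
print(tokens)     # ['2', '1', ' ', '+', ' ', '4', '3', ' ', '=']
print(place_ids)  # [1, 2, 0, 0, 0, 1, 2, 0, 0]
```

The two lists stay aligned position by position, which is the property the model relies on when it receives `input_ids` and `place_ids` together.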
56
+
57
+ ## Usage
58
+
59
+ ```python
60
+ from pathlib import Path
61
+
62
+ import torch
63
+ from transformers import AutoModelForCausalLM
64
+
65
+ from tokenizer_utils import load_tokenizer
66
+
67
+ model_dir = Path(".")
68
+
69
+ model = AutoModelForCausalLM.from_pretrained(
70
+     model_dir,
71
+     trust_remote_code=True,
72
+ ).eval()
73
+ tokenizer = load_tokenizer(model_dir)
74
+
75
+ text = "12 + 34 ="
76
+ encoding = tokenizer.encode(text)
77
+
78
+ input_ids = torch.tensor([encoding.input_ids])
79
+ place_ids = torch.tensor([encoding.place_ids])
80
+ role_ids = torch.tensor([encoding.role_ids])
81
+
82
+ with torch.no_grad():
83
+     outputs = model(
83
+         input_ids=input_ids,
84
+         place_ids=place_ids,
85
+         role_ids=role_ids,
86
+     )
88
+ ```
89
+
90
+ For correct results, do not rely on `pipeline("text-generation")` unless it is wrapped to provide `place_ids` and `role_ids`.
91
+
92
+ ## Evaluation
93
+
94
+ ### ArithMark 2.0
95
+
96
+ Use the included fusion-aware benchmark script:
97
+
98
+ ```bash
99
+ python benchmark_fusion_arithmark.py \
100
+   --checkpoint . \
101
+   --tokenizer-dir . \
102
+   --data-path arithmark_2.0.jsonl \
103
+   --batch-size 64 \
104
+   --device cuda \
105
+   --output benchmark_results/fusion_arithmark_2.0_results.json
106
+ ```
107
+
108
+ ### lm-evaluation-harness
109
+
110
+ Use the included launcher so the `atom2.7m` model wrapper is registered:
111
+
112
+ ```bash
113
+ python lm_eval_fusion run \
114
+   --model atom2.7m \
115
+   --model_args pretrained=.,tokenizer_dir=. \
116
+   --tasks hellaswag,arc_easy,arc_challenge,piqa \
117
+   --device cuda:0 \
118
+   --batch_size auto \
119
+   --output_path benchmark_results/lm_eval
120
+ ```
121
+
122
+ The wrapper uses `tokenizer_utils.load_tokenizer()` and forwards `place_ids` and `role_ids` to the model.
123
+
124
+ ## Results
125
+
126
+ | Benchmark | Metric | Value |
127
+ | --- | --- | ---: |
128
+ | ArithMark 2.0 | acc | 0.6380 |
129
+ | arc_challenge | acc_norm | 0.2261 |
130
+ | arc_easy | acc_norm | 0.3270 |
131
+ | hellaswag | acc_norm | 0.2733 |
132
+ | piqa | acc_norm | 0.5305 |
133
+
134
+ ## Training Data
135
+
136
+ The pretraining mixture targeted about 3.5B tokens:
137
+
138
+ - Ultra-FineWeb: 900M
139
+ - FineWeb-Edu: 900M
140
+ - FineMath: 450M
141
+ - Cosmopedia-v2: 337.5M
142
+ - UltraData-Math-L2-preview: 337.5M
143
+ - Ultra-FineWeb-L3-en-QA-Synthetic: 225M
144
+ - Synthetic-Arithmetic: 350M
145
+
146
+ Synthetic-Arithmetic is AtomCalc-style canonical integer equation data. The training curriculum is included as `pretraining_curriculum.json`.
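As a quick check on the mixture, the per-source budgets listed above can be summed and turned into shares. The numbers are copied from the list; the dictionary name and the output formatting are only illustrative.

```python
# Token budgets in millions, copied from the list above.
MIXTURE_M_TOKENS = {
    "Ultra-FineWeb": 900.0,
    "FineWeb-Edu": 900.0,
    "FineMath": 450.0,
    "Cosmopedia-v2": 337.5,
    "UltraData-Math-L2-preview": 337.5,
    "Ultra-FineWeb-L3-en-QA-Synthetic": 225.0,
    "Synthetic-Arithmetic": 350.0,
}

total = sum(MIXTURE_M_TOKENS.values())
shares = {name: tokens / total for name, tokens in MIXTURE_M_TOKENS.items()}

for name, share in shares.items():
    print(f"{name:34s} {share:6.1%}")
print(f"total: {total / 1000:.1f}B tokens")
```

The budgets add up to 3.5B tokens, and synthetic arithmetic accounts for one tenth of the mixture.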
147
+
148
+ ## Limitations
149
+
150
+ - This is a very small model and should be treated as an experimental research artifact.
151
+ - The custom tokenizer makes plain `AutoTokenizer` or default `lm_eval --model hf` unsuitable for final reported numbers.
152
+ - Numeric text is represented least-significant-digit first internally.
153
+ - Role annotations intentionally target strict integer equations, not broad math prose, decimals, rationals, or QA formats.
154
+
155
+ ## Files
156
+
157
+ - `model.safetensors`: model weights
158
+ - `config.json`, `config.py`, `configuration_gpt.py`, `model.py`: custom model code
159
+ - `tokenizer.json`, `tokenizer_utils.py`: tokenizer files and fusion wrapper
160
+ - `benchmark_fusion_arithmark.py`: ArithMark evaluation
161
+ - `lm_eval_fusion.py`, `lm_eval_fusion`: lm-eval custom model wrapper
162
+ - `pretraining_curriculum.json`: training curriculum
benchmark_fusion_arithmark.py ADDED
@@ -0,0 +1,291 @@
1
+ """Score an Atom2.7m checkpoint on ArithMark 2.0."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from collections import Counter
7
+ from contextlib import nullcontext
8
+ import json
9
+ from pathlib import Path
10
+ import re
11
+ import urllib.request
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from transformers import AutoModelForCausalLM
16
+
17
+ from tokenizer_utils import SPECIAL_TOKENS, FusionTokenizer, load_tokenizer
18
+
19
+
20
+ DATA_URL = (
21
+     "https://huggingface.co/datasets/AxiomicLabs/Arithmark-2.0/"
22
+     "resolve/main/arithmark_2.0.jsonl"
23
+ )
24
+ PAD_ID = SPECIAL_TOKENS.index("<|pad|>")
25
+
26
+
27
+ def ensure_data(path: Path) -> Path:
28
+     if path.exists():
29
+         return path
30
+     path.parent.mkdir(parents=True, exist_ok=True)
31
+     urllib.request.urlretrieve(DATA_URL, path)
32
+     return path
33
+
34
+
35
+ def load_examples(path: Path, *, max_examples: int = 0) -> list[dict]:
36
+     examples = []
37
+     with path.open("r", encoding="utf-8") as handle:
38
+         for line in handle:
39
+             if not line.strip():
40
+                 continue
41
+             examples.append(json.loads(line))
42
+             if max_examples > 0 and len(examples) >= max_examples:
43
+                 break
44
+     return examples
45
+
46
+
47
+ def _encoded_choice(
48
+     tokenizer: FusionTokenizer,
49
+     context: str,
50
+     ending: str,
51
+ ) -> tuple[list[int], list[int], list[int], int]:
52
+     context_encoding = tokenizer.encode(context)
53
+     full_encoding = tokenizer.encode(context + ending)
54
+     continuation_length = len(full_encoding.input_ids) - len(context_encoding.input_ids)
55
+     return (
56
+         full_encoding.input_ids,
57
+         full_encoding.place_ids,
58
+         full_encoding.role_ids,
59
+         continuation_length,
60
+     )
61
+
62
+
63
+ @torch.inference_mode()
64
+ def evaluate(
65
+     model: torch.nn.Module,
66
+     tokenizer: FusionTokenizer,
67
+     examples: list[dict],
68
+     *,
69
+     device: torch.device,
70
+     batch_size: int,
71
+     dump_failures: bool = False,
72
+     failure_operator_count: int | None = None,
73
+     max_failures: int = 100,
74
+ ) -> dict:
75
+     correct = 0
76
+     total = 0
77
+     by_operator_count: dict[str, list[int]] = {}
78
+     by_topic: dict[str, list[int]] = {}
79
+     failures: list[dict] = []
80
+     failure_summary: Counter[tuple[str, str, str]] = Counter()
81
+     model.eval()
82
+
83
+     for start in range(0, len(examples), batch_size):
84
+         batch_examples = examples[start : start + batch_size]
85
+         encoded = []
86
+         offsets = []
87
+         for example in batch_examples:
88
+             flat_start = len(encoded)
89
+             encoded.extend(
90
+                 _encoded_choice(tokenizer, example["ctx"], ending)
91
+                 for ending in example["endings"]
92
+             )
93
+             offsets.append((flat_start, len(example["endings"])))
94
+
95
+         max_length = max(len(item[0]) for item in encoded)
96
+         input_ids = torch.full(
97
+             (len(encoded), max_length),
98
+             PAD_ID,
99
+             dtype=torch.long,
100
+             device=device,
101
+         )
102
+         place_ids = torch.zeros_like(input_ids)
103
+         role_ids = torch.zeros_like(input_ids)
104
+         attention_mask = torch.zeros_like(input_ids, dtype=torch.bool)
105
+         lengths = []
106
+         continuation_lengths = []
107
+         for row, (ids, places, roles, continuation_length) in enumerate(encoded):
108
+             length = len(ids)
109
+             input_ids[row, :length] = torch.tensor(ids, device=device)
110
+             place_ids[row, :length] = torch.tensor(places, device=device)
111
+             role_ids[row, :length] = torch.tensor(roles, device=device)
112
+             attention_mask[row, :length] = True
113
+             lengths.append(length)
114
+             continuation_lengths.append(continuation_length)
115
+
116
+         autocast = (
117
+             torch.autocast(device_type="cuda", dtype=torch.bfloat16)
118
+             if device.type == "cuda"
119
+             else nullcontext()
120
+         )
121
+         with autocast:
122
+             logits = model(
123
+                 input_ids=input_ids,
124
+                 place_ids=place_ids,
125
+                 role_ids=role_ids,
126
+                 attention_mask=attention_mask,
127
+             ).logits
128
+         log_probs = F.log_softmax(logits.float(), dim=-1)
129
+
130
+         for example_index, example in enumerate(batch_examples):
131
+             flat_start, choice_count = offsets[example_index]
132
+             likelihoods = []
133
+             for choice_index in range(choice_count):
134
+                 row = flat_start + choice_index
135
+                 length = lengths[row]
136
+                 continuation_length = continuation_lengths[row]
137
+                 continuation_start = length - continuation_length
138
+                 likelihood = 0.0
139
+                 for position in range(continuation_start, length):
140
+                     likelihood += float(
141
+                         log_probs[row, position - 1, input_ids[row, position]].item()
142
+                     )
143
+                 likelihoods.append(likelihood)
144
+
145
+             prediction = max(range(choice_count), key=likelihoods.__getitem__)
146
+             label = int(example["label"])
147
+             matched = prediction == label
148
+             correct += int(matched)
149
+             total += 1
150
+             metadata = example.get("metadata", {})
151
+             operator_count = str(metadata.get("operator_count", "unknown"))
152
+             topic = str(metadata.get("topic", "unknown"))
153
+             for grouped, key in (
154
+                 (by_operator_count, operator_count),
155
+                 (by_topic, topic),
156
+             ):
157
+                 group = grouped.setdefault(key, [0, 0])
158
+                 group[0] += int(matched)
159
+                 group[1] += 1
160
+
161
+             if not matched and dump_failures:
162
+                 op_count_int = None
163
+                 try:
164
+                     op_count_int = int(operator_count)
165
+                 except ValueError:
166
+                     pass
167
+                 if failure_operator_count is None or op_count_int == failure_operator_count:
168
+                     context = str(example["ctx"]).strip()
169
+                     expression = context[:-1].strip() if context.endswith("=") else context
170
+                     operands = [int(value) for value in re.findall(r"\d+", expression)]
171
+                     operator = "".join(re.findall(r"[+\-*/]", expression))
172
+                     predicted_answer = str(example["endings"][prediction]).strip()
173
+                     correct_answer = str(example["endings"][label]).strip()
174
+                     width = max((len(str(value)) for value in operands), default=0)
175
+                     failure_summary[(topic, operator, f"width={width}")] += 1
176
+                     if len(failures) < max_failures:
177
+                         failures.append(
178
+                             {
179
+                                 "ctx": context,
180
+                                 "topic": topic,
181
+                                 "operator_count": operator_count,
182
+                                 "operator": operator,
183
+                                 "operands": operands,
184
+                                 "max_operand_digits": width,
185
+                                 "correct_answer": correct_answer,
186
+                                 "predicted_answer": predicted_answer,
187
+                                 "choices": [str(value).strip() for value in example["endings"]],
188
+                                 "choice_scores": [round(value, 4) for value in likelihoods],
189
+                                 "score_margin_correct_minus_predicted": round(
190
+                                     likelihoods[label] - likelihoods[prediction],
191
+                                     4,
192
+                                 ),
193
+                             }
194
+                         )
195
+
196
+     results = {
197
+         "benchmark": "arithmark_2.0",
198
+         "model_type": "atom2.7m",
199
+         "accuracy": correct / max(total, 1),
200
+         "correct": correct,
201
+         "total": total,
202
+         "by_operator_count": {
203
+             key: {
204
+                 "accuracy": values[0] / max(values[1], 1),
205
+                 "correct": values[0],
206
+                 "total": values[1],
207
+             }
208
+             for key, values in sorted(by_operator_count.items())
209
+         },
210
+         "by_topic": {
211
+             key: {
212
+                 "accuracy": values[0] / max(values[1], 1),
213
+                 "correct": values[0],
214
+                 "total": values[1],
215
+             }
216
+             for key, values in sorted(by_topic.items())
217
+         },
218
+     }
219
+     if dump_failures:
220
+         results["failure_summary"] = {
221
+             "|".join(key): value
222
+             for key, value in failure_summary.most_common()
223
+         }
224
+         results["failures"] = failures
225
+     return results
226
+
227
+
228
+ def parse_args() -> argparse.Namespace:
229
+     parser = argparse.ArgumentParser(description=__doc__)
230
+     parser.add_argument("--checkpoint", type=Path, default=Path("outputs/fusion_run/final_model"))
231
+     parser.add_argument("--tokenizer-dir", type=Path, default=Path("tokenizer_4k"))
232
+     parser.add_argument("--data-path", type=Path, default=Path("arithmark_2.0.jsonl"))
233
+     parser.add_argument("--batch-size", type=int, default=64)
234
+     parser.add_argument("--device", default="auto")
235
+     parser.add_argument("--output", type=Path)
236
+     parser.add_argument(
237
+         "--max-examples",
238
+         type=int,
239
+         default=0,
240
+         help="Evaluate only the first N examples. Default evaluates all examples.",
241
+     )
242
+     parser.add_argument(
243
+         "--dump-failures",
244
+         action="store_true",
245
+         help="Include incorrectly scored examples and grouped failure summary.",
246
+     )
247
+     parser.add_argument(
248
+         "--failure-operator-count",
249
+         type=int,
250
+         default=None,
251
+         help="Only dump failures with this operator count, e.g. 1 for easy examples.",
252
+     )
253
+     parser.add_argument("--max-failures", type=int, default=100)
254
+     return parser.parse_args()
255
+
256
+
257
+ def main() -> None:
258
+     args = parse_args()
259
+     if args.device == "auto":
260
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
261
+     else:
262
+         device = torch.device(args.device)
263
+
264
+     data_path = ensure_data(args.data_path)
265
+     examples = load_examples(data_path, max_examples=args.max_examples)
266
+     model = AutoModelForCausalLM.from_pretrained(
267
+         args.checkpoint,
268
+         trust_remote_code=True,
269
+     ).to(device)
270
+     tokenizer = load_tokenizer(args.tokenizer_dir)
271
+     results = evaluate(
272
+         model,
273
+         tokenizer,
274
+         examples,
275
+         device=device,
276
+         batch_size=args.batch_size,
277
+         dump_failures=args.dump_failures,
278
+         failure_operator_count=args.failure_operator_count,
279
+         max_failures=args.max_failures,
280
+     )
281
+     print(json.dumps(results, indent=2, sort_keys=True))
282
+     if args.output:
283
+         args.output.parent.mkdir(parents=True, exist_ok=True)
284
+         args.output.write_text(
285
+             json.dumps(results, indent=2, sort_keys=True) + "\n",
286
+             encoding="utf-8",
287
+         )
288
+
289
+
290
+ if __name__ == "__main__":
291
+     main()
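The scoring rule used by `evaluate` above (sum the log-probabilities of the continuation tokens, then pick the choice with the highest sum) can be shown without a model. The sketch below uses plain Python lists in place of tensors; `log_softmax` and `continuation_log_likelihood` are illustrative names, the logits are made up, and the context is assumed to contain at least one token so that `position - 1` never goes negative.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one list of logits."""
    peak = max(row)
    log_norm = peak + math.log(sum(math.exp(value - peak) for value in row))
    return [value - log_norm for value in row]


def continuation_log_likelihood(logits, token_ids, continuation_length):
    """Sum log P(token[t] | tokens[<t]) over the trailing continuation tokens.

    logits[t] holds the scores for the token at position t + 1, so the token at
    `position` is scored with the row at `position - 1`.
    """
    start = len(token_ids) - continuation_length
    total = 0.0
    for position in range(start, len(token_ids)):
        total += log_softmax(logits[position - 1])[token_ids[position]]
    return total


# Toy vocabulary of three tokens; uniform logits give log(1/3) per scored token.
toy_logits = [[0.0, 0.0, 0.0]] * 4
score = continuation_log_likelihood(toy_logits, [0, 1, 2, 1], continuation_length=2)
print(score)
```

Because the sum is not length-normalized, longer endings accumulate more negative terms; that matches the loop in `evaluate`, which also adds raw log-probabilities.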
bg.png ADDED

Git LFS Details

  • SHA256: bae1d107bc4c5ce774ff4112511336df26c07036ad65ec37e1560fdeea982930
  • Pointer size: 132 Bytes
  • Size of remote file: 3.08 MB
config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+   "architectures": [
3
+     "GPTForCausalLM"
4
+   ],
5
+   "auto_map": {
6
+     "AutoConfig": "config.GPTConfig",
7
+     "AutoModelForCausalLM": "model.GPTForCausalLM"
8
+   },
9
+   "block_size": 512,
10
+   "dtype": "float32",
11
+   "head_dim": 48,
12
+   "hidden_size": 192,
13
+   "intermediate_size": 480,
14
+   "labels_are_shifted": true,
15
+   "max_position_embeddings": 512,
16
+   "model_type": "gpt",
17
+   "num_attention_heads": 4,
18
+   "num_hidden_layers": 5,
19
+   "num_key_value_heads": 2,
20
+   "place_vocab_size": 66,
21
+   "rms_norm_eps": 1e-06,
22
+   "role_vocab_size": 12,
23
+   "rope_theta": 5000.0,
24
+   "transformers_version": "4.57.6",
25
+   "use_place_embeddings": true,
26
+   "use_role_embeddings": true,
27
+   "vocab_size": 4096,
28
+   "xsa_projection": true
29
+ }
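The shape fields in this config have to satisfy a few constraints that `GPTConfig` in `config.py` enforces or implies. A minimal sketch of those checks follows, run against values copied from the JSON above; `shape_problems` is a hypothetical helper, not a function in the repository.

```python
import json

# Shape-related fields copied from config.json above.
CONFIG = json.loads(
    """
    {
      "block_size": 512,
      "head_dim": 48,
      "hidden_size": 192,
      "max_position_embeddings": 512,
      "num_attention_heads": 4,
      "num_key_value_heads": 2
    }
    """
)


def shape_problems(config):
    """Return constraint violations as strings; an empty list means consistent."""
    problems = []
    heads = config["num_attention_heads"]
    kv_heads = config["num_key_value_heads"]
    if heads % kv_heads != 0:
        problems.append("num_attention_heads must be divisible by num_key_value_heads")
    if config["head_dim"] % 2 != 0:
        problems.append("head_dim must be even for RoPE")
    # GPTConfig copies block_size into max_position_embeddings.
    if config["max_position_embeddings"] != config["block_size"]:
        problems.append("max_position_embeddings should equal block_size")
    return problems


queries_per_kv_head = CONFIG["num_attention_heads"] // CONFIG["num_key_value_heads"]
print(shape_problems(CONFIG), queries_per_kv_head)
```

With 4 attention heads and 2 KV heads, each key/value head is shared by 2 query heads.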
config.py ADDED
@@ -0,0 +1,233 @@
1
+ """Environment-driven training configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import math
7
+ import uuid
8
+ from dataclasses import dataclass, field
9
+ from pathlib import Path
10
+
11
+ from transformers import PretrainedConfig
12
+
13
+
14
+ DEFAULT_VOCAB_SIZE = 4096
15
+ DEFAULT_HIDDEN_SIZE = 192
16
+ DEFAULT_NUM_HIDDEN_LAYERS = 5
17
+ DEFAULT_NUM_ATTENTION_HEADS = 4
18
+ DEFAULT_NUM_KEY_VALUE_HEADS = 2
19
+ DEFAULT_HEAD_DIM = DEFAULT_HIDDEN_SIZE // DEFAULT_NUM_ATTENTION_HEADS
20
+ DEFAULT_INTERMEDIATE_SIZE = DEFAULT_HIDDEN_SIZE * 5 // 2
21
+ DEFAULT_BLOCK_SIZE = 512
22
+ DEFAULT_ROPE_THETA = 5000.0
23
+ DEFAULT_PLACE_VOCAB_SIZE = 66
24
+ DEFAULT_ROLE_VOCAB_SIZE = 12
25
+
26
+
27
+ class GPTConfig(PretrainedConfig):
28
+     model_type = "gpt"
29
+
30
+     def __init__(
31
+         self,
32
+         vocab_size: int = DEFAULT_VOCAB_SIZE,
33
+         hidden_size: int = DEFAULT_HIDDEN_SIZE,
34
+         num_hidden_layers: int = DEFAULT_NUM_HIDDEN_LAYERS,
35
+         num_attention_heads: int = DEFAULT_NUM_ATTENTION_HEADS,
36
+         num_key_value_heads: int | None = DEFAULT_NUM_KEY_VALUE_HEADS,
37
+         intermediate_size: int | None = DEFAULT_INTERMEDIATE_SIZE,
38
+         head_dim: int | None = None,
39
+         block_size: int = DEFAULT_BLOCK_SIZE,
40
+         rope_theta: float = DEFAULT_ROPE_THETA,
41
+         rms_norm_eps: float = 1e-6,
42
+         xsa_projection: bool = True,
43
+         tie_word_embeddings: bool = True,
44
+         labels_are_shifted: bool = False,
45
+         use_place_embeddings: bool = True,
46
+         use_role_embeddings: bool = True,
47
+         place_vocab_size: int = DEFAULT_PLACE_VOCAB_SIZE,
48
+         role_vocab_size: int = DEFAULT_ROLE_VOCAB_SIZE,
49
+         **kwargs,
50
+     ):
51
+         if num_key_value_heads is None:
52
+             num_key_value_heads = num_attention_heads
53
+         if head_dim is None:
54
+             if hidden_size % num_attention_heads != 0:
55
+                 raise ValueError("hidden_size must be divisible by num_attention_heads")
56
+             head_dim = hidden_size // num_attention_heads
57
+         if intermediate_size is None:
58
+             intermediate_size = hidden_size * 4
59
+         if num_attention_heads % num_key_value_heads != 0:
60
+             raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
61
+         if head_dim % 2 != 0:
62
+             raise ValueError("head_dim must be even for RoPE")
63
+
64
+         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
65
+         self.vocab_size = int(vocab_size)
66
+         self.hidden_size = int(hidden_size)
67
+         self.num_hidden_layers = int(num_hidden_layers)
68
+         self.num_attention_heads = int(num_attention_heads)
69
+         self.num_key_value_heads = int(num_key_value_heads)
70
+         self.intermediate_size = int(intermediate_size)
71
+         self.head_dim = int(head_dim)
72
+         self.block_size = int(block_size)
73
+         self.max_position_embeddings = int(block_size)
74
+         self.rope_theta = float(rope_theta)
75
+         self.rms_norm_eps = float(rms_norm_eps)
76
+         self.xsa_projection = bool(xsa_projection)
77
+         self.labels_are_shifted = bool(labels_are_shifted)
78
+         self.use_place_embeddings = bool(use_place_embeddings)
79
+         self.use_role_embeddings = bool(use_role_embeddings)
80
+         self.place_vocab_size = int(place_vocab_size)
81
+         self.role_vocab_size = int(role_vocab_size)
82
+
83
+
84
+ def _bool_env(name: str, default: bool) -> bool:
85
+     raw = os.environ.get(name)
86
+     if raw is None:
87
+         return default
88
+     return raw.strip().lower() in {"1", "true", "yes", "on"}
89
+
90
+
91
+ def _path_env(name: str, default: str) -> str:
92
+     return str(Path(os.environ.get(name, default)).expanduser())
93
+
94
+
95
+ @dataclass
96
+ class Hyperparameters:
97
+     data_dir: str = field(default_factory=lambda: _path_env("DATA_DIR", "."))
98
+     tokenized_dir: str = field(default_factory=lambda: _path_env("TOKENIZED_DIR", "tokenized2"))
99
+     tokenizer_dir: str = field(default_factory=lambda: _path_env("TOKENIZER_DIR", "tokenizer_4k"))
100
+     tokenizer_path: str = field(default_factory=lambda: os.environ.get("TOKENIZER_PATH", ""))
101
+     curriculum_path: str = field(default_factory=lambda: os.environ.get("CURRICULUM_PATH", ""))
102
+     mix_weights_path: str = field(default_factory=lambda: os.environ.get("MIX_WEIGHTS_PATH", ""))
103
+     run_id: str = field(default_factory=lambda: os.environ.get("RUN_ID", str(uuid.uuid4())))
104
+     seed: int = field(default_factory=lambda: int(os.environ.get("SEED", "1337")))
105
+     rank: int = field(init=False)
106
+
107
+     iterations: int = field(default_factory=lambda: int(os.environ.get("ITERATIONS", "10000")))
108
+     requested_train_tokens: int = field(init=False)
109
+     train_tokens: int = field(init=False)
110
+     decay_start_frac: float = field(default_factory=lambda: float(os.environ.get("DECAY_START_FRAC", "0.7")))
111
+     warmup_steps: int = field(default_factory=lambda: int(os.environ.get("WARMUP_STEPS", "0")))
112
+     lr_warmup_steps: int = field(default_factory=lambda: int(os.environ.get("LR_WARMUP_STEPS", "500")))
113
+     train_batch_tokens: int = field(default_factory=lambda: int(os.environ.get("TRAIN_BATCH_TOKENS", str(DEFAULT_BLOCK_SIZE * 512))))
114
+     train_seq_len: int = field(init=False)
115
+     eval_seq_len: int = field(init=False)
116
+     grad_accum_steps: int = field(default_factory=lambda: int(os.environ.get("GRAD_ACCUM_STEPS", "4")))
117
+     train_log_every: int = field(default_factory=lambda: int(os.environ.get("TRAIN_LOG_EVERY", "100")))
118
+     train_log_dense_steps: int = field(default_factory=lambda: int(os.environ.get("TRAIN_LOG_DENSE_STEPS", "100")))
119
+     train_log_ramp_steps: int = field(
120
+         default_factory=lambda: int(
121
+             os.environ.get(
122
+                 "TRAIN_LOG_RAMP_STEPS",
123
+                 os.environ.get("TRAIN_LOG_FIRST_STEPS", "500"),
124
+             )
125
+         )
126
+     )
127
+
128
+     val_batch_tokens: int = field(default_factory=lambda: int(os.environ.get("VAL_BATCH_TOKENS", str(DEFAULT_BLOCK_SIZE * 256))))
129
+     val_loss_every: int = field(default_factory=lambda: int(os.environ.get("VAL_LOSS_EVERY", "1000")))
130
+     val_max_tokens: int = field(default_factory=lambda: int(os.environ.get("VAL_MAX_TOKENS", "10_000_000")))
131
+
132
+     vocab_size: int = field(default_factory=lambda: int(os.environ.get("VOCAB_SIZE", str(DEFAULT_VOCAB_SIZE))))
133
+     hidden_size: int = field(default_factory=lambda: int(os.environ.get("HIDDEN_SIZE", os.environ.get("MODEL_DIM", str(DEFAULT_HIDDEN_SIZE)))))
134
+     num_hidden_layers: int = field(default_factory=lambda: int(os.environ.get("NUM_HIDDEN_LAYERS", os.environ.get("NUM_LAYERS", str(DEFAULT_NUM_HIDDEN_LAYERS)))))
135
+     num_attention_heads: int = field(default_factory=lambda: int(os.environ.get("NUM_ATTENTION_HEADS", os.environ.get("NUM_HEADS", str(DEFAULT_NUM_ATTENTION_HEADS)))))
136
+     num_key_value_heads: int = field(default_factory=lambda: int(os.environ.get("NUM_KEY_VALUE_HEADS", os.environ.get("NUM_KV_HEADS", str(DEFAULT_NUM_KEY_VALUE_HEADS)))))
137
+     head_dim: int = field(init=False)
138
+     intermediate_size: int = field(default_factory=lambda: int(os.environ.get("INTERMEDIATE_SIZE", os.environ.get("INTERMEDIATE", str(DEFAULT_INTERMEDIATE_SIZE)))))
139
+     block_size: int = field(default_factory=lambda: int(os.environ.get("BLOCK_SIZE", str(DEFAULT_BLOCK_SIZE))))
140
+     rope_theta: float = field(default_factory=lambda: float(os.environ.get("ROPE_THETA", os.environ.get("ROPE_BASE", str(DEFAULT_ROPE_THETA)))))
141
+     rms_norm_eps: float = field(default_factory=lambda: float(os.environ.get("RMS_NORM_EPS", "1e-6")))
142
+     xsa_projection: bool = field(default_factory=lambda: _bool_env("XSA_PROJECTION", True))
143
+     tie_word_embeddings: bool = field(default_factory=lambda: _bool_env("TIE_WORD_EMBEDDINGS", _bool_env("TIE_EMBEDDINGS", True)))
144
+     use_place_embeddings: bool = field(default_factory=lambda: _bool_env("USE_PLACE_EMBEDDINGS", True))
145
+     use_role_embeddings: bool = field(default_factory=lambda: _bool_env("USE_ROLE_EMBEDDINGS", True))
146
+     place_vocab_size: int = field(default_factory=lambda: int(os.environ.get("PLACE_VOCAB_SIZE", str(DEFAULT_PLACE_VOCAB_SIZE))))
147
+     role_vocab_size: int = field(default_factory=lambda: int(os.environ.get("ROLE_VOCAB_SIZE", str(DEFAULT_ROLE_VOCAB_SIZE))))
148
+
149
+     min_lr: float = field(default_factory=lambda: float(os.environ.get("MIN_LR", "0.0")))
150
+     lr: float = field(default_factory=lambda: float(os.environ.get("LR", "0.004")))
151
+     beta1: float = field(default_factory=lambda: float(os.environ.get("BETA1", "0.9")))
152
+     beta2: float = field(default_factory=lambda: float(os.environ.get("BETA2", "0.95")))
153
+     adam_eps: float = field(default_factory=lambda: float(os.environ.get("ADAM_EPS", "1e-8")))
154
+     weight_decay: float = field(default_factory=lambda: float(os.environ.get("WEIGHT_DECAY", "0.005")))
155
+
156
+     compile_model: bool = field(default_factory=lambda: _bool_env("COMPILE_MODEL", True))
157
+     autocast: bool = field(default_factory=lambda: _bool_env("AUTOCAST", True))
158
+     bf16: bool = field(default_factory=lambda: _bool_env("BF16", True))
159
+     device: str = field(default_factory=lambda: os.environ.get("DEVICE", "auto"))
160
+
161
+     output_dir: str = field(default_factory=lambda: _path_env("OUTPUT_DIR", "outputs"))
162
+     checkpoint_name: str = field(default_factory=lambda: os.environ.get("CHECKPOINT_NAME", "final_model"))
163
+     logfile: str = field(init=False)
164
+     model_path: str = field(init=False)
165
+     is_main_process: bool = True
166
+     train_files: str = field(init=False)
167
+     val_files: str = field(init=False)
169
+     def __post_init__(self) -> None:
170
+         self.rank = int(os.environ.get("RANK", "0"))
171
+         if self.rank < 0:
172
+             raise ValueError("RANK must be non-negative")
173
+         self.is_main_process = self.rank == 0
174
+         self.head_dim = int(os.environ.get("HEAD_DIM", str(self.hidden_size // self.num_attention_heads)))
175
+         self.train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", str(self.block_size)))
176
+         self.eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", os.environ.get("TRAIN_SEQ_LEN", str(self.train_seq_len))))
177
+         token_alignment = self.train_seq_len * self.grad_accum_steps
178
+         if self.train_batch_tokens % token_alignment != 0:
179
+             raise ValueError(
180
+                 "TRAIN_BATCH_TOKENS must be divisible by TRAIN_SEQ_LEN * GRAD_ACCUM_STEPS"
181
+             )
182
+         requested_train_tokens = int(os.environ.get("TRAIN_TOKENS", "0"))
183
+         self.requested_train_tokens = requested_train_tokens or self.iterations * self.train_batch_tokens
184
+         if self.requested_train_tokens <= 0:
185
+             raise ValueError("TRAIN_TOKENS must be positive")
186
+         self.train_tokens = self.requested_train_tokens - (self.requested_train_tokens % token_alignment)
187
+         if self.train_tokens <= 0:
188
+             raise ValueError(
189
+                 "TRAIN_TOKENS must provide at least TRAIN_SEQ_LEN * GRAD_ACCUM_STEPS tokens"
190
+             )
191
+         self.iterations = math.ceil(self.train_tokens / self.train_batch_tokens)
192
+         tokenized = Path(self.tokenized_dir)
193
+         self.train_files = os.environ.get("TRAIN_FILES", str(tokenized / "*" / "shard_*.bin"))
194
+         self.val_files = os.environ.get("VAL_FILES", os.environ.get("TRAIN_FILES", self.train_files))
195
+         explicit_legacy_mix = bool(os.environ.get("MIX_WEIGHTS_PATH"))
196
+         if not self.curriculum_path and not explicit_legacy_mix:
197
+             tokenized_curriculum = tokenized / "curriculum.json"
198
+             default_curriculum = Path("pretraining_curriculum.json")
199
+             if tokenized_curriculum.exists():
200
+                 self.curriculum_path = str(tokenized_curriculum)
201
+             elif default_curriculum.exists():
202
+                 self.curriculum_path = str(default_curriculum)
203
+         if not self.mix_weights_path and not self.curriculum_path:
204
+             mix_weights_path = tokenized / "mix_weights.json"
205
+             self.mix_weights_path = str(mix_weights_path) if mix_weights_path.exists() else ""
206
+         if not self.tokenizer_path:
207
+             tok_dir = Path(self.tokenizer_dir)
208
+             json_path = tok_dir / "tokenizer.json"
209
+             self.tokenizer_path = str(json_path if json_path.exists() else tok_dir)
210
+         out = Path(self.output_dir)
211
+         self.logfile = os.environ.get("LOGFILE", str(out / "logs" / f"{self.run_id}.txt"))
212
+         self.model_path = os.environ.get("MODEL_PATH", str(out / self.checkpoint_name))
213
+
214
+     def to_model_config(self) -> GPTConfig:
215
+         return GPTConfig(
216
+             vocab_size=self.vocab_size,
217
+             hidden_size=self.hidden_size,
218
+             num_hidden_layers=self.num_hidden_layers,
219
+             num_attention_heads=self.num_attention_heads,
220
+             num_key_value_heads=self.num_key_value_heads,
221
+             head_dim=self.head_dim,
222
+             intermediate_size=self.intermediate_size,
223
+             block_size=self.block_size,
224
+             rope_theta=self.rope_theta,
225
+             rms_norm_eps=self.rms_norm_eps,
226
+             xsa_projection=self.xsa_projection,
227
+             tie_word_embeddings=self.tie_word_embeddings,
228
+             use_place_embeddings=self.use_place_embeddings,
229
+             use_role_embeddings=self.use_role_embeddings,
230
+             place_vocab_size=self.place_vocab_size,
231
+             role_vocab_size=self.role_vocab_size,
232
+             labels_are_shifted=True,
233
+         )
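The token-budget rounding in `Hyperparameters.__post_init__` is easy to misread, so here is the same arithmetic as a standalone sketch. `plan_training` is a hypothetical function; its defaults mirror the defaults in `config.py` (512-token sequences, 4 gradient-accumulation steps, 512 * 512 tokens per batch). The 3.5B request in the example is an assumption taken from the README's approximate mixture size, not a value recorded in this commit.

```python
import math


def plan_training(requested_tokens, batch_tokens=512 * 512, seq_len=512, grad_accum_steps=4):
    """Round a token budget down to the alignment unit and derive the step count."""
    alignment = seq_len * grad_accum_steps
    if batch_tokens % alignment != 0:
        raise ValueError("batch_tokens must be divisible by seq_len * grad_accum_steps")
    # Drop the remainder so the budget is a whole number of aligned units.
    train_tokens = requested_tokens - (requested_tokens % alignment)
    if train_tokens <= 0:
        raise ValueError("requested_tokens is smaller than one aligned unit")
    # The last step may be partial, hence the ceiling.
    iterations = math.ceil(train_tokens / batch_tokens)
    return train_tokens, iterations


print(plan_training(3_500_000_000))         # (3499999232, 13352)
print(plan_training(10_000 * 512 * 512))    # (2621440000, 10000)
```

A 3.5B request loses 768 tokens to alignment and needs 13,352 steps, while the default of 10,000 iterations corresponds to about 2.6B tokens.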
configuration_gpt.py ADDED
@@ -0,0 +1,35 @@
1
+ """Exports for the GPT model configuration.
2
+
3
+ New code should import these from :mod:`GPT.config`.
4
+ """
5
+
6
+ from .config import (
7
+     DEFAULT_BLOCK_SIZE,
8
+     DEFAULT_HEAD_DIM,
9
+     DEFAULT_HIDDEN_SIZE,
10
+     DEFAULT_INTERMEDIATE_SIZE,
11
+     DEFAULT_NUM_ATTENTION_HEADS,
12
+     DEFAULT_NUM_HIDDEN_LAYERS,
13
+     DEFAULT_NUM_KEY_VALUE_HEADS,
14
+     DEFAULT_PLACE_VOCAB_SIZE,
15
+     DEFAULT_ROPE_THETA,
16
+     DEFAULT_ROLE_VOCAB_SIZE,
17
+     DEFAULT_VOCAB_SIZE,
18
+     GPTConfig,
19
+ )
20
+
21
+
22
+ __all__ = [
23
+     "DEFAULT_BLOCK_SIZE",
24
+     "DEFAULT_HEAD_DIM",
25
+     "DEFAULT_HIDDEN_SIZE",
26
+     "DEFAULT_INTERMEDIATE_SIZE",
27
+     "DEFAULT_NUM_ATTENTION_HEADS",
28
+     "DEFAULT_NUM_HIDDEN_LAYERS",
29
+     "DEFAULT_NUM_KEY_VALUE_HEADS",
30
+     "DEFAULT_PLACE_VOCAB_SIZE",
31
+     "DEFAULT_ROPE_THETA",
32
+     "DEFAULT_ROLE_VOCAB_SIZE",
33
+     "DEFAULT_VOCAB_SIZE",
34
+     "GPTConfig",
35
+ ]
lm_eval_fusion ADDED
@@ -0,0 +1,9 @@
1
+ #!/usr/bin/env python
2
+ """Run lm-eval with the local Atom2.7m model registered."""
3
+
4
+ import lm_eval_fusion # noqa: F401
5
+ from lm_eval.__main__ import cli_evaluate
6
+
7
+
8
+ if __name__ == "__main__":
9
+     cli_evaluate()
lm_eval_fusion.py ADDED
@@ -0,0 +1,299 @@
+"""lm-eval wrapper for Atom2.7m checkpoints.
+
+The standard ``hf`` lm-eval model does not use the fusion tokenizer wrapper and
+does not pass arithmetic feature streams. This model keeps lm-eval's
+log-likelihood interface while encoding with ``tokenizer_utils.load_tokenizer``
+and forwarding ``place_ids`` and ``role_ids``.
+"""
+
+from __future__ import annotations
+
+from contextlib import nullcontext
+from pathlib import Path
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+from lm_eval.api.model import LM
+from lm_eval.api.registry import register_model
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM
+
+from tokenizer_utils import EOT_ID, FusionTokenizer, load_tokenizer
+
+
+def _parse_bool(value: Any, default: bool = False) -> bool:
+    if value is None:
+        return default
+    if isinstance(value, bool):
+        return value
+    return str(value).strip().lower() in {"1", "true", "yes", "on"}
+
+
+def _parse_batch_size(value: int | str | None, max_batch_size: int | None) -> int:
+    if value is None:
+        return 1
+    if isinstance(value, int):
+        return value
+    text = str(value).strip().lower()
+    if text == "auto" or text.startswith("auto:"):
+        return int(max_batch_size or 64)
+    return int(text)
+
+
+def _dtype_from_name(value: str | torch.dtype | None) -> torch.dtype | None:
+    if value is None or value == "auto":
+        return None
+    if isinstance(value, torch.dtype):
+        return value
+    normalized = str(value).replace("torch.", "").lower()
+    if normalized in {"bf16", "bfloat16"}:
+        return torch.bfloat16
+    if normalized in {"fp16", "float16", "half"}:
+        return torch.float16
+    if normalized in {"fp32", "float32", "float"}:
+        return torch.float32
+    raise ValueError(f"Unsupported dtype: {value!r}")
+
+
+@register_model("atom2.7m")
+class FusionGPTLM(LM):
+    """Fusion-tokenizer GPT adapter for lm-eval log-likelihood tasks."""
+
+    def __init__(
+        self,
+        pretrained: str = "outputs/fusion_run/final_model",
+        tokenizer_dir: str = "tokenizer_4k",
+        batch_size: int | str | None = 1,
+        max_batch_size: int | None = 64,
+        max_length: int | None = None,
+        device: str | None = "cuda",
+        dtype: str | torch.dtype | None = "auto",
+        mixed_precision_dtype: str | torch.dtype | None = "auto",
+        trust_remote_code: bool | str | None = None,
+        **_: Any,
+    ) -> None:
+        super().__init__()
+        del trust_remote_code
+        if device is None or device == "auto":
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        self._device = torch.device(device)
+        self.batch_size = _parse_batch_size(batch_size, max_batch_size)
+        self.tokenizer: FusionTokenizer = load_tokenizer(Path(tokenizer_dir))
+        self.model = AutoModelForCausalLM.from_pretrained(
+            Path(pretrained),
+            trust_remote_code=True,
+        ).to(self.device)
+        model_dtype = _dtype_from_name(dtype)
+        if model_dtype is not None:
+            self.model = self.model.to(dtype=model_dtype)
+        if mixed_precision_dtype == "auto":
+            self.mixed_precision_dtype = (
+                torch.bfloat16 if self.device.type == "cuda" else None
+            )
+        else:
+            self.mixed_precision_dtype = _dtype_from_name(mixed_precision_dtype)
+        self.model.eval()
+        self.max_length = int(
+            max_length
+            or getattr(self.model.config, "block_size", None)
+            or getattr(self.model.config, "max_position_embeddings", 512)
+        )
+
+    @property
+    def eot_token_id(self) -> int:
+        return EOT_ID
+
+    def tok_encode(
+        self,
+        string: str,
+        add_special_tokens: bool | None = None,
+        left_truncate_len: int | None = None,
+        **_: Any,
+    ) -> list[int]:
+        del add_special_tokens
+        ids = self.tokenizer.encode(string).input_ids
+        if left_truncate_len is not None:
+            ids = ids[-left_truncate_len:]
+        return ids
+
+    def tok_decode(self, tokens, skip_special_tokens: bool = True) -> str:
+        if isinstance(tokens, int):
+            tokens = [tokens]
+        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
+
+    def _encode_request(
+        self,
+        context: str,
+        continuation: str,
+    ) -> tuple[list[int], list[int], list[int], list[int], int]:
+        if context == "":
+            continuation_encoding = self.tokenizer.encode(continuation)
+            ids = [self.eot_token_id] + continuation_encoding.input_ids
+            place_ids = [0] + continuation_encoding.place_ids
+            role_ids = [0] + continuation_encoding.role_ids
+            context_len = 1
+            continuation_ids = continuation_encoding.input_ids
+        else:
+            n_spaces = len(context) - len(context.rstrip())
+            if n_spaces > 0:
+                continuation = context[-n_spaces:] + continuation
+                context = context[:-n_spaces]
+            full_encoding = self.tokenizer.encode(context + continuation)
+            context_encoding = self.tokenizer.encode(context)
+            ids = full_encoding.input_ids
+            place_ids = full_encoding.place_ids
+            role_ids = full_encoding.role_ids
+            context_len = len(context_encoding.input_ids)
+            continuation_ids = ids[context_len:]
+
+        if not continuation_ids:
+            raise ValueError("Continuation encoded to zero tokens")
+        return ids, place_ids, role_ids, continuation_ids, context_len
+
+    def loglikelihood(
+        self,
+        requests: list["Instance"],
+        disable_tqdm: bool = False,
+    ) -> list[tuple[float, bool]]:
+        encoded = [
+            self._encode_request(context, continuation)
+            for context, continuation in tqdm(
+                [req.args for req in requests],
+                desc="Fusion tokenizing inputs",
+                disable=disable_tqdm,
+            )
+        ]
+        results: list[tuple[float, bool]] = []
+        for start in tqdm(
+            range(0, len(encoded), self.batch_size),
+            desc="Running fusion loglikelihood requests",
+            disable=disable_tqdm or self.rank != 0,
+        ):
+            batch = encoded[start : start + self.batch_size]
+            rows = []
+            row_places = []
+            row_roles = []
+            row_targets = []
+            row_score_slices = []
+            for ids, place_ids, role_ids, continuation_ids, context_len in batch:
+                window_start = max(0, len(ids) - (self.max_length + 1))  # keep the last max_length + 1 tokens
+                window_ids = ids[window_start:]
+                window_places = place_ids[window_start:]
+                window_roles = role_ids[window_start:]
+                input_ids = window_ids[:-1]  # model inputs
+                targets = window_ids[1:]  # next-token targets, shifted by one
+                full_score_start = context_len - 1  # the logit at position i predicts token i + 1
+                full_score_end = len(ids) - 1
+                score_start = max(full_score_start, window_start) - window_start
+                score_end = full_score_end - window_start
+                if score_end <= score_start:
+                    raise ValueError("No continuation tokens remain after truncation")
+                scored_continuation_ids = continuation_ids[-(score_end - score_start) :]
+                rows.append(input_ids)
+                row_places.append(window_places[:-1])
+                row_roles.append(window_roles[:-1])
+                row_targets.append(targets)
+                row_score_slices.append((score_start, score_end, scored_continuation_ids))
+
+            max_len = max(len(row) for row in rows)
+            input_tensor = torch.full(  # right-pad every row with the EOT id
+                (len(rows), max_len),
+                self.eot_token_id,
+                dtype=torch.long,
+                device=self.device,
+            )
+            place_tensor = torch.zeros_like(input_tensor)
+            role_tensor = torch.zeros_like(input_tensor)
+            attention_mask = torch.zeros_like(input_tensor, dtype=torch.bool)  # True marks real tokens
+            target_tensor = torch.full_like(input_tensor, self.eot_token_id)
+            for row, (ids, places, roles, targets) in enumerate(
+                zip(rows, row_places, row_roles, row_targets, strict=True)
+            ):
+                length = len(ids)
+                input_tensor[row, :length] = torch.tensor(ids, device=self.device)
+                place_tensor[row, :length] = torch.tensor(places, device=self.device)
+                role_tensor[row, :length] = torch.tensor(roles, device=self.device)
+                target_tensor[row, :length] = torch.tensor(targets, device=self.device)
+                attention_mask[row, :length] = True
+
+            autocast = (
+                torch.autocast(
+                    device_type=self.device.type,
+                    dtype=self.mixed_precision_dtype,
+                    enabled=self.mixed_precision_dtype is not None,
+                )
+                if self.device.type == "cuda"
+                else nullcontext()
+            )
+            with torch.inference_mode(), autocast:
+                logits = self.model(
+                    input_ids=input_tensor,
+                    place_ids=place_tensor,
+                    role_ids=role_tensor,
+                    attention_mask=attention_mask,
+                ).logits
+                log_probs = F.log_softmax(logits.float(), dim=-1)
+
+            for row, (score_start, score_end, continuation_ids) in enumerate(row_score_slices):
+                row_log_probs = log_probs[row, score_start:score_end]
+                row_targets_for_score = target_tensor[row, score_start:score_end]
+                token_log_probs = torch.gather(  # log-probability of each target token
+                    row_log_probs,
+                    1,
+                    row_targets_for_score.unsqueeze(-1),
+                ).squeeze(-1)
+                greedy = torch.equal(  # True if greedy decoding reproduces the continuation
+                    row_log_probs.argmax(dim=-1),
+                    torch.tensor(continuation_ids, dtype=torch.long, device=self.device),
+                )
+                results.append((float(token_log_probs.sum().item()), bool(greedy)))
+
+        return results
+
+    def loglikelihood_rolling(
+        self,
+        requests: list["Instance"],
+        disable_tqdm: bool = False,
+    ) -> list[float]:
+        results = []
+        for (text,) in tqdm(
+            [req.args for req in requests],
+            desc="Running fusion rolling loglikelihood",
+            disable=disable_tqdm or self.rank != 0,
+        ):
+            encoding = self.tokenizer.encode(text)
+            ids = encoding.input_ids
+            places = encoding.place_ids
+            roles = encoding.role_ids
+            total = 0.0
+            start = 0
+            while start < len(ids):
+                end = min(len(ids), start + self.max_length)
+                prefix = [self.eot_token_id] if start == 0 else ids[start - 1 : start]
+                chunk_ids = prefix + ids[start:end]
+                chunk_places = [0] + places[start:end] if start == 0 else places[start - 1 : end]
+                chunk_roles = [0] + roles[start:end] if start == 0 else roles[start - 1 : end]
+                input_ids = torch.tensor([chunk_ids[:-1]], dtype=torch.long, device=self.device)
+                place_ids = torch.tensor([chunk_places[:-1]], dtype=torch.long, device=self.device)
+                role_ids = torch.tensor([chunk_roles[:-1]], dtype=torch.long, device=self.device)
+                targets = torch.tensor(chunk_ids[1:], dtype=torch.long, device=self.device)
+                with torch.inference_mode():
+                    logits = self.model(
+                        input_ids=input_ids,
+                        place_ids=place_ids,
+                        role_ids=role_ids,
+                    ).logits[0]
+                    log_probs = F.log_softmax(logits.float(), dim=-1)
+                total += float(
+                    torch.gather(log_probs, 1, targets.unsqueeze(-1)).sum().item()
+                )
+                start = end
+            results.append(total)
+        return results
+
+    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
+        raise NotImplementedError(
+            "FusionGPTLM currently supports loglikelihood tasks. "
+            "Use tasks with multiple-choice/loglikelihood output."
+        )
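
The truncation and scoring indices in `loglikelihood` are easy to get wrong by one, so here is a minimal standalone sketch of the same arithmetic. The helper name `score_window` is hypothetical and not part of the commit; it only restates the index computation from `lm_eval_fusion.py` on plain integers.

```python
def score_window(num_tokens: int, context_len: int, max_length: int) -> tuple[int, int, int]:
    """Return (window_start, score_start, score_end) for one request.

    Hypothetical helper mirroring the index arithmetic in ``loglikelihood``:
    the last ``max_length + 1`` tokens are kept, and because the logit at
    position ``i`` predicts token ``i + 1``, scoring starts one position
    before the first continuation token.
    """
    window_start = max(0, num_tokens - (max_length + 1))
    full_score_start = context_len - 1
    full_score_end = num_tokens - 1
    score_start = max(full_score_start, window_start) - window_start
    score_end = full_score_end - window_start
    if score_end <= score_start:
        raise ValueError("No continuation tokens remain after truncation")
    return window_start, score_start, score_end


# 10 tokens, the first 7 are context, window of 512: nothing is truncated.
print(score_window(num_tokens=10, context_len=7, max_length=512))  # (0, 6, 9)
# Same request with only 4 input positions: the last 5 tokens are kept.
print(score_window(num_tokens=10, context_len=7, max_length=4))  # (5, 1, 4)
```

In both calls the slice `[score_start:score_end]` covers exactly the three continuation tokens, which is why the summed log-probability is unaffected by left truncation of the context.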
model.py ADDED
@@ -0,0 +1,335 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+from transformers import PreTrainedModel
+from transformers.cache_utils import DynamicCache
+from transformers.generation.utils import GenerationMixin
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from .config import GPTConfig
+
+
+CONTROL_TENSOR_NAME_PATTERNS = (
+    "scale",
+    "gate",
+    "gain",
+    "norm",
+    "ln_",
+    "rms",
+)
+
+
+class CastedLinear(nn.Linear):
+    """Store linear params in FP32, cast to activation dtype for matmul."""
+
+    def forward(self, x: Tensor) -> Tensor:
+        weight = self.weight.to(dtype=x.dtype)
+        bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None
+        return F.linear(x, weight, bias)
+
+
+def restore_fp32_params(model: nn.Module) -> None:
+    """Keep linear weights and control params in FP32 after dtype conversion."""
+    for module in model.modules():
+        if isinstance(module, CastedLinear):
+            module.float()
+    for name, param in model.named_parameters():
+        if (
+            param.ndim < 2
+            or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+        ) and param.dtype != torch.float32:
+            param.data = param.data.float()
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim, eps=1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x):
+        rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
+        return (x.float() * rms).to(dtype=x.dtype) * self.weight.to(dtype=x.dtype)
+
+
+def build_rope_inv_freq(head_dim, theta=2500.0):
+    return 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
+
+
+def precompute_rope_cos_sin(head_dim, seq_len, theta=2500.0):
+    freqs = build_rope_inv_freq(head_dim, theta)
+    t = torch.arange(seq_len, dtype=torch.float32)
+    freqs = torch.outer(t, freqs)
+    return freqs.cos(), freqs.sin()
+
+
+def _apply_rope(x, cos, sin):
+    x_float = x.float()
+    x_pair = x_float.reshape(*x_float.shape[:-1], -1, 2)
+    even = x_pair[..., 0]
+    odd = x_pair[..., 1]
+    cos = cos.unsqueeze(0).unsqueeze(0)
+    sin = sin.unsqueeze(0).unsqueeze(0)
+    x_rot = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1)
+    return x_rot.flatten(-2).type_as(x)
+
+
+def apply_rotary_emb(q, k, freqs_cis):
+    cos, sin = freqs_cis
+    return _apply_rope(q, cos, sin), _apply_rope(k, cos, sin)
+
+class GPTAttention(nn.Module):
+    def __init__(self, config, layer_idx):
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.n_head = config.num_attention_heads
+        self.n_kv_heads = config.num_key_value_heads
+        self.head_dim = config.head_dim
+        self.n_rep = self.n_head // self.n_kv_heads  # query heads per KV head
+        self.xsa_projection = config.xsa_projection
+
+        self.q_proj = CastedLinear(config.hidden_size, self.n_head * self.head_dim, bias=False)
+        self.k_proj = CastedLinear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
+        self.v_proj = CastedLinear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
+        self.o_proj = CastedLinear(self.n_head * self.head_dim, config.hidden_size, bias=False)
+
+    def _xsa_efficient(self, y: Tensor, v_current: Tensor) -> Tensor:
+        B, H, T, D = y.shape
+        Hkv = v_current.size(1)
+        group = H // Hkv
+
+        y_g = y.reshape(B, Hkv, group, T, D)
+        v_n = F.normalize(v_current, dim=-1).unsqueeze(2)  # unit value vectors, shared per KV group
+
+        proj = (y_g * v_n).sum(dim=-1, keepdim=True) * v_n  # component of y along the token's own value
+        return (y_g - proj).reshape(B, H, T, D)
+
+    def forward(self, x, freqs_cis, past_key_value=None, use_cache=False, attention_mask=None):
+        B, T, _ = x.size()
+
+        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+        k_current = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
+        v_current = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
+
+        q, k_current = apply_rotary_emb(q, k_current, freqs_cis)
+
+        if past_key_value is not None:
+            k, v = past_key_value.update(k_current, v_current, self.layer_idx)
+        else:
+            k, v = k_current, v_current
+
+        S = k.size(2)
+
+        # Causal masking applies when the cache holds nothing beyond the current T tokens.
+        is_causal = past_key_value is None or past_key_value.get_seq_length(self.layer_idx) == T
+
+        attn_mask = None
+        if attention_mask is not None:
+            key_pad = attention_mask.to(torch.bool)[:, None, None, :]  # broadcast over heads and queries
+
+            if is_causal and T > 1:
+                causal = torch.ones(T, S, dtype=torch.bool, device=x.device).tril(diagonal=S - T)
+                attn_mask = key_pad & causal[None, None, :, :]
+            else:
+                attn_mask = key_pad.expand(B, 1, T, S)
+
+            is_causal = False  # the explicit mask already encodes causality
+
+        y = F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=attn_mask,
+            is_causal=is_causal,
+            enable_gqa=(self.n_kv_heads != self.n_head),
+        )
+
+        if self.xsa_projection:
+            y = self._xsa_efficient(y, v_current)
+
+        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
+        return self.o_proj(y)
+
+
+class GPTMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.w_gate = CastedLinear(config.hidden_size, config.intermediate_size, bias=False)
+        self.w_up = CastedLinear(config.hidden_size, config.intermediate_size, bias=False)
+        self.w_down = CastedLinear(config.intermediate_size, config.hidden_size, bias=False)
+
+    def forward(self, x):
+        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
+
+
+class GPTBlock(nn.Module):
+    def __init__(self, config, layer_idx):
+        super().__init__()
+        self.ln_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.attn = GPTAttention(config, layer_idx)
+        self.ln_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.mlp = GPTMLP(config)
+
+    def forward(self, x, freqs_cis, past_key_value=None, use_cache=False, attention_mask=None):
+        x = x + self.attn(self.ln_1(x), freqs_cis, past_key_value, use_cache, attention_mask=attention_mask)
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class GPTPreTrainedModel(PreTrainedModel):
+    config_class = GPTConfig
+    base_model_prefix = "transformer"
+    supports_gradient_checkpointing = False
+
+    def _init_weights(self, module):
+        std = self.config.hidden_size ** -0.5
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+
+
+class GPTForCausalLM(GPTPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.transformer = nn.ModuleDict(dict(
+            wte=nn.Embedding(config.vocab_size, config.hidden_size),
+            h=nn.ModuleList([GPTBlock(config, i) for i in range(config.num_hidden_layers)]),
+            ln_f=RMSNorm(config.hidden_size, eps=config.rms_norm_eps),
+        ))
+        self.lm_head = CastedLinear(config.hidden_size, config.vocab_size, bias=False)
+        if config.tie_word_embeddings:
+            self.lm_head.weight = self.transformer["wte"].weight
+        if getattr(config, "use_place_embeddings", True):
+            self.place_embeddings = nn.Embedding(
+                config.place_vocab_size,
+                config.hidden_size,
+                padding_idx=0,
+            )
+        else:
+            self.place_embeddings = None
+        if getattr(config, "use_role_embeddings", True):
+            self.role_embeddings = nn.Embedding(
+                config.role_vocab_size,
+                config.hidden_size,
+                padding_idx=0,
+            )
+        else:
+            self.role_embeddings = None
+        self._freqs_cis_cache = None
+        self.post_init()
+        with torch.no_grad():
+            if self.place_embeddings is not None:
+                self.place_embeddings.weight[0].zero_()
+            if self.role_embeddings is not None:
+                self.role_embeddings.weight[0].zero_()
+        restore_fp32_params(self)
+
+    def _apply(self, fn):
+        module = super()._apply(fn)
+        restore_fp32_params(self)
+        return module
+
+    def get_input_embeddings(self):
+        return self.transformer["wte"]
+
+    def set_input_embeddings(self, value):
+        self.transformer["wte"] = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def embed_tokens(self, input_ids, *, place_ids=None, role_ids=None, **kwargs):
+        embeddings = self.transformer["wte"](input_ids)
+        if self.place_embeddings is not None:
+            if place_ids is None:
+                place_ids = torch.zeros_like(input_ids)
+            if place_ids.shape != input_ids.shape:
+                raise ValueError("place_ids must match input_ids shape")
+            embeddings = embeddings + self.place_embeddings(place_ids)
+        if self.role_embeddings is not None:
+            if role_ids is None:
+                role_ids = torch.zeros_like(input_ids)
+            if role_ids.shape != input_ids.shape:
+                raise ValueError("role_ids must match input_ids shape")
+            embeddings = embeddings + self.role_embeddings(role_ids)
+        return embeddings
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
+        if past_key_values is not None and past_key_values.get_seq_length() > 0:
+            input_ids = input_ids[:, -1:]
+            if kwargs.get("place_ids") is not None:
+                kwargs["place_ids"] = kwargs["place_ids"][:, -1:]
+            if kwargs.get("role_ids") is not None:
+                kwargs["role_ids"] = kwargs["role_ids"][:, -1:]
+        return {
+            "input_ids": input_ids,
+            "place_ids": kwargs.get("place_ids"),
+            "role_ids": kwargs.get("role_ids"),
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+        }
+
+    def _get_freqs_cis(self, seq_len, device):
+        cache = self._freqs_cis_cache
+        if cache is None or cache[0].device != device or cache[0].size(0) < seq_len:
+            cache = tuple(
+                tensor.to(device)
+                for tensor in precompute_rope_cos_sin(self.config.head_dim, seq_len, self.config.rope_theta)
+            )
+            if torch.is_inference_mode_enabled():
+                return cache[0][:seq_len], cache[1][:seq_len]
+            self._freqs_cis_cache = cache
+        return cache[0][:seq_len], cache[1][:seq_len]
+
+    def forward(
+        self,
+        input_ids,
+        attention_mask=None,
+        labels=None,
+        past_key_values: Optional[DynamicCache] = None,
+        use_cache=False,
+        **kwargs,
+    ):
+        B, T = input_ids.size()
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
+        past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
+        x = self.embed_tokens(input_ids, **kwargs)
+        cos, sin = self._get_freqs_cis(past_len + T, input_ids.device)
+        freqs_cis = (
+            cos[past_len:past_len + T],
+            sin[past_len:past_len + T],
+        )
+
+        for block in self.transformer["h"]:
+            x = block(x, freqs_cis, past_key_values if use_cache else None, use_cache, attention_mask=attention_mask)
+
+        x = self.transformer["ln_f"](x)
+        logits = self.lm_head(x)
+
+        loss = None
+        if labels is not None:
+            if getattr(self.config, "labels_are_shifted", False):
+                loss = F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), labels.reshape(-1))
+            else:
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+                loss = F.cross_entropy(shift_logits.float().view(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=past_key_values if use_cache else None,
+        )
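
The rotary embedding in `_apply_rope` rotates each (even, odd) pair of a head vector by a position-dependent angle. The sketch below restates that rotation in plain Python to show its key property: the query/key dot product depends only on the distance between positions. The names `rope_angles`, `apply_rope` and `dot` are hypothetical, the vectors are toy values, and `theta=2500.0` is taken from the default in the function signatures above (the value stored in the checkpoint's config is not shown here).

```python
import math


def rope_angles(head_dim: int, position: int, theta: float = 2500.0) -> list[float]:
    """Angle for pair i is position * theta ** (-2i / head_dim)."""
    return [position * theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]


def apply_rope(x: list[float], position: int, theta: float = 2500.0) -> list[float]:
    """Rotate each interleaved (even, odd) pair of x, as _apply_rope does."""
    out: list[float] = []
    for i, angle in enumerate(rope_angles(len(x), position, theta)):
        even, odd = x[2 * i], x[2 * i + 1]
        out.append(even * math.cos(angle) - odd * math.sin(angle))
        out.append(even * math.sin(angle) + odd * math.cos(angle))
    return out


def dot(a: list[float], b: list[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]

# The same offset (3) at two different absolute positions.
near = dot(apply_rope(q, 5), apply_rope(k, 2))
far = dot(apply_rope(q, 13), apply_rope(k, 10))
print(math.isclose(near, far, rel_tol=1e-9, abs_tol=1e-9))
```

Because each pair is only rotated, vector norms are unchanged, which is why the model can apply the rotation to queries and keys without rescaling attention scores.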
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a21f5910c898c5464320d3f01b15ccde1eb278073266b221e48e9ec15ccbe899
+size 10930496
pretraining_curriculum.json ADDED
@@ -0,0 +1,46 @@
+{
+  "version": 1,
+  "transition_fraction": 0.1,
+  "stages": [
+    {
+      "name": "early",
+      "start": 0.0,
+      "end": 0.4,
+      "weights": {
+        "Ultra-FineWeb": 0.5,
+        "FineWeb-Edu": 0.38,
+        "FineMath": 0.05,
+        "Cosmopedia-v2": 0.05,
+        "UltraData-Math-L2-preview": 0.02
+      }
+    },
+    {
+      "name": "mid",
+      "start": 0.4,
+      "end": 0.8,
+      "weights": {
+        "Ultra-FineWeb": 0.12,
+        "FineWeb-Edu": 0.22,
+        "FineMath": 0.18,
+        "Cosmopedia-v2": 0.13,
+        "UltraData-Math-L2-preview": 0.12,
+        "Ultra-FineWeb-L3-en-QA-Synthetic": 0.05,
+        "Synthetic-Arithmetic": 0.18
+      }
+    },
+    {
+      "name": "late",
+      "start": 0.8,
+      "end": 1.0,
+      "weights": {
+        "Ultra-FineWeb": 0.105,
+        "FineWeb-Edu": 0.21,
+        "FineMath": 0.14,
+        "Cosmopedia-v2": 0.14,
+        "UltraData-Math-L2-preview": 0.105,
+        "Ultra-FineWeb-L3-en-QA-Synthetic": 0.2,
+        "Synthetic-Arithmetic": 0.1
+      }
+    }
+  ]
+}
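
One way to read the curriculum file is as a lookup from training progress (a fraction in [0, 1]) to a set of dataset sampling weights. The sketch below is only an illustration of that reading: the helper names `stage_for` and `weights_sum_to_one` are hypothetical, the half-open interval convention is an assumption, and `transition_fraction` is ignored because this diff does not show how the training code blends neighbouring stages.

```python
import math

# Stage boundaries and weights copied from pretraining_curriculum.json.
STAGES = [
    {"name": "early", "start": 0.0, "end": 0.4, "weights": {
        "Ultra-FineWeb": 0.5, "FineWeb-Edu": 0.38, "FineMath": 0.05,
        "Cosmopedia-v2": 0.05, "UltraData-Math-L2-preview": 0.02}},
    {"name": "mid", "start": 0.4, "end": 0.8, "weights": {
        "Ultra-FineWeb": 0.12, "FineWeb-Edu": 0.22, "FineMath": 0.18,
        "Cosmopedia-v2": 0.13, "UltraData-Math-L2-preview": 0.12,
        "Ultra-FineWeb-L3-en-QA-Synthetic": 0.05, "Synthetic-Arithmetic": 0.18}},
    {"name": "late", "start": 0.8, "end": 1.0, "weights": {
        "Ultra-FineWeb": 0.105, "FineWeb-Edu": 0.21, "FineMath": 0.14,
        "Cosmopedia-v2": 0.14, "UltraData-Math-L2-preview": 0.105,
        "Ultra-FineWeb-L3-en-QA-Synthetic": 0.2, "Synthetic-Arithmetic": 0.1}},
]


def stage_for(progress: float, stages: list[dict] = STAGES) -> dict:
    """Return the stage covering `progress`, assuming [start, end) intervals."""
    if not 0.0 <= progress <= 1.0:
        raise ValueError("progress must be in [0, 1]")
    for stage in stages:
        if stage["start"] <= progress < stage["end"]:
            return stage
    return stages[-1]  # progress == 1.0 belongs to the final stage


def weights_sum_to_one(stage: dict) -> bool:
    """Check that a stage's sampling weights form a probability distribution."""
    return math.isclose(sum(stage["weights"].values()), 1.0, abs_tol=1e-9)


print(stage_for(0.5)["name"])  # mid
print(all(weights_sum_to_one(stage) for stage in STAGES))  # True
```

The check confirms that each stage's weights sum to 1, so they can be used directly as sampling probabilities without renormalization.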
requirements.txt ADDED
@@ -0,0 +1,6 @@
+torch
+transformers
+tokenizers
+safetensors
+tqdm
+lm-eval
special_tokens_map.json ADDED
@@ -0,0 +1,9 @@
+{
+  "additional_special_tokens": [
+    "<|endoftext|>"
+  ],
+  "bos_token": "<|bos|>",
+  "eos_token": "<|eos|>",
+  "pad_token": "<|pad|>",
+  "unk_token": "<|unk|>"
+}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
+{
+  "additional_special_tokens": [
+    "<|endoftext|>"
+  ],
+  "bos_token": "<|bos|>",
+  "eos_token": "<|eos|>",
+  "model_max_length": 512,
+  "pad_token": "<|pad|>",
+  "tokenizer_class": "GPT2TokenizerFast",
+  "unk_token": "<|unk|>"
+}
tokenizer_utils.py ADDED
@@ -0,0 +1,328 @@
+"""Shared construction and loading helpers for the project's tokenizer."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+import json
+from pathlib import Path
+import re
+from typing import Any, Iterable
+
+
+SPECIAL_TOKENS = [
+    "<|pad|>",
+    "<|bos|>",
+    "<|eos|>",
+    "<|unk|>",
+    "<|endoftext|>",
+]
+EOT_ID = SPECIAL_TOKENS.index("<|endoftext|>")
+ARITHMETIC_TOKENS = ("+", "-", "*", "/", "=", "(", ")")
+MAX_PLACE_ID = 64
+PLACE_OVERFLOW_ID = MAX_PLACE_ID + 1
+PLACE_VOCAB_SIZE = PLACE_OVERFLOW_ID + 1
+RESULT_ROLE_ID = 10
+SPACE_ROLE_ID = 11
+ROLE_VOCAB_SIZE = SPACE_ROLE_ID + 1
+MAX_OPERAND_ROLES = 9
+
+
+@dataclass(frozen=True)
+class FusionEncoding:
+    ids: list[int]
+    place_ids: list[int]
+    role_ids: list[int]
+    tokens: list[str] = field(default_factory=list)
+
+    @property
+    def input_ids(self) -> list[int]:
+        return self.ids
+
+    def __len__(self) -> int:
+        return len(self.ids)
+
+    def __iter__(self):
+        return iter(self.ids)
+
+    def __post_init__(self) -> None:
+        if not (len(self.ids) == len(self.place_ids) == len(self.role_ids)):
+            raise ValueError("Fusion tokenizer streams must have equal length")
+
+
+def build_tokenizer() -> Any:
+    """Build a byte-level BPE tokenizer with explicit lossless boundaries."""
+    from tokenizers import Regex, Tokenizer, decoders, models, pre_tokenizers
+
+    tokenizer = Tokenizer(models.BPE(unk_token="<|unk|>"))
+    tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+        [
+            pre_tokenizers.Split(
+                Regex(r"\s+|\d|[+\-*/=()]|[^\s\d+\-*/=()]+"),
+                behavior="isolated",
+            ),
+            pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False),
+        ]
+    )
+    tokenizer.decoder = decoders.ByteLevel()
+    return tokenizer
+
+
69
+
70
+ class FusionTokenizer:
71
+ """Runtime wrapper adding LSD-first digit streams to a trained BPE tokenizer."""
72
+
73
+ _digit_span_re = re.compile(r"\d+")
74
+
75
+ def __init__(self, tokenizer: Any):
76
+ self.tokenizer = tokenizer
77
+ self._digit_token_ids = frozenset(
78
+ token_id
79
+ for digit in "0123456789"
80
+ if (token_id := self.tokenizer.token_to_id(digit)) is not None
81
+ )
82
+ self._digit_id_to_text = {
83
+ int(self.tokenizer.token_to_id(digit)): digit
84
+ for digit in "0123456789"
85
+ if self.tokenizer.token_to_id(digit) is not None
86
+ }
87
+ self._equals_id = self.tokenizer.token_to_id("=")
88
+ self._special_token_ids = frozenset(
89
+ token_id
90
+ for token in SPECIAL_TOKENS
91
+ if (token_id := self.tokenizer.token_to_id(token)) is not None
92
+ )
93
+ if len(self._digit_token_ids) != 10:
94
+ raise ValueError("Tokenizer vocabulary must contain atomic digit tokens 0-9")
95
+ if self._equals_id is None:
96
+ raise ValueError("Tokenizer vocabulary must contain an atomic '=' token")
97
+
98
+ def __getattr__(self, name: str) -> Any:
99
+ return getattr(self.tokenizer, name)
100
+
101
+ @property
102
+ def digit_token_ids(self) -> frozenset[int]:
103
+ return self._digit_token_ids
104
+
105
+ @property
106
+ def special_token_ids(self) -> frozenset[int]:
107
+ return self._special_token_ids
108
+
109
+ def get_vocab_size(self, with_added_tokens: bool = True) -> int:
110
+ return int(self.tokenizer.get_vocab_size(with_added_tokens=with_added_tokens))
111
+
112
+ def get_vocab(self, with_added_tokens: bool = True) -> dict[str, int]:
113
+ return self.tokenizer.get_vocab(with_added_tokens=with_added_tokens)
114
+
115
+ def token_to_id(self, token: str) -> int | None:
116
+ return self.tokenizer.token_to_id(token)
117
+
118
+ def id_to_token(self, token_id: int) -> str | None:
119
+ return self.tokenizer.id_to_token(int(token_id))
120
+
121
+ @classmethod
122
+ def _reverse_digit_spans(cls, text: str) -> str:
123
+ return cls._digit_span_re.sub(lambda match: match.group(0)[::-1], text)
124
+
125
+ def _decode_token_piece(self, token_id: int) -> str:
126
+ return self.tokenizer.decode([int(token_id)], skip_special_tokens=False)
127
+
128
+ @staticmethod
129
+ def _is_equation_whitespace(piece: str) -> bool:
130
+ return bool(piece) and piece.isspace() and "\n" not in piece and "\r" not in piece
131
+
+    def _is_equation_piece(self, token_id: int, piece: str) -> bool:
+        if token_id in self._special_token_ids:
+            return False
+        if token_id in self._digit_token_ids:
+            return True
+        if self._is_equation_whitespace(piece):
+            return True
+        return len(piece) == 1 and piece in set(ARITHMETIC_TOKENS)
+
+    def _annotate_equation_span(
+        self,
+        ids: list[int],
+        pieces: list[str],
+        start: int,
+        end: int,
+        role_ids: list[int],
+    ) -> None:
+        # Only spans containing exactly one '=' token are treated as equations.
+        equals_positions = [
+            index
+            for index in range(start, end)
+            if ids[index] == self._equals_id
+        ]
+        if len(equals_positions) != 1:
+            return
+        equals_position = equals_positions[0]
+
+        # Collect maximal runs of digit tokens as half-open (start, end) ranges.
+        digit_runs: list[tuple[int, int]] = []
+        index = start
+        while index < end:
+            if ids[index] not in self._digit_token_ids:
+                index += 1
+                continue
+            run_start = index
+            while index < end and ids[index] in self._digit_token_ids:
+                index += 1
+            digit_runs.append((run_start, index))
+
+        # Runs before '=' are operands; runs after '=' belong to the result.
+        operand_runs = [(a, b) for a, b in digit_runs if b <= equals_position]
+        result_runs = [(a, b) for a, b in digit_runs if a > equals_position]
+        if not operand_runs or not result_runs or len(operand_runs) > MAX_OPERAND_ROLES:
+            return
+
+        for index in range(start, end):
+            if self._is_equation_whitespace(pieces[index]):
+                role_ids[index] = SPACE_ROLE_ID
+
+        # Operand roles are numbered from 1 in order of appearance.
+        for role, (run_start, run_end) in enumerate(operand_runs, start=1):
+            for index in range(run_start, run_end):
+                role_ids[index] = role
+        for run_start, run_end in result_runs:
+            for index in range(run_start, run_end):
+                role_ids[index] = RESULT_ROLE_ID
+
+    def annotate_ids(self, ids: Iterable[int]) -> tuple[list[int], list[int]]:
+        input_ids = [int(token_id) for token_id in ids]
+        place_ids = [0] * len(input_ids)
+        role_ids = [0] * len(input_ids)
+        pieces = [self._decode_token_piece(token_id) for token_id in input_ids]
+
+        # Place IDs: 1-based position within each digit run, capped at
+        # PLACE_OVERFLOW_ID. Non-digit tokens keep place ID 0.
+        index = 0
+        while index < len(input_ids):
+            if input_ids[index] not in self._digit_token_ids:
+                index += 1
+                continue
+            run_start = index
+            while index < len(input_ids) and input_ids[index] in self._digit_token_ids:
+                offset = index - run_start + 1
+                place_ids[index] = min(offset, PLACE_OVERFLOW_ID)
+                index += 1
+
+        # Role IDs: find maximal spans of equation-like tokens and annotate each.
+        span_start: int | None = None
+        for index, (token_id, piece) in enumerate(zip(input_ids, pieces, strict=True)):
+            if self._is_equation_piece(token_id, piece):
+                if span_start is None:
+                    span_start = index
+                continue
+            if span_start is not None:
+                self._annotate_equation_span(input_ids, pieces, span_start, index, role_ids)
+                span_start = None
+        if span_start is not None:
+            self._annotate_equation_span(input_ids, pieces, span_start, len(input_ids), role_ids)
+
+        return place_ids, role_ids
+
+    def encode(self, text: str, *args, **kwargs) -> FusionEncoding:
+        transformed = self._reverse_digit_spans(text)
+        encoding = self.tokenizer.encode(transformed, *args, **kwargs)
+        ids = [int(token_id) for token_id in encoding.ids]
+        place_ids, role_ids = self.annotate_ids(ids)
+        return FusionEncoding(
+            ids=ids,
+            place_ids=place_ids,
+            role_ids=role_ids,
+            tokens=list(getattr(encoding, "tokens", [])),
+        )
+
+    def encode_batch(self, texts: list[str], *args, **kwargs) -> list[FusionEncoding]:
+        return [self.encode(text, *args, **kwargs) for text in texts]
+
+    def decode(
+        self,
+        token_ids: Iterable[int],
+        skip_special_tokens: bool = True,
+    ) -> str:
+        pieces: list[str] = []
+        text_ids: list[int] = []
+        digit_buffer: list[str] = []
+
+        def flush_text() -> None:
+            if text_ids:
+                pieces.append(
+                    self.tokenizer.decode(
+                        text_ids,
+                        skip_special_tokens=skip_special_tokens,
+                    )
+                )
+                text_ids.clear()
+
+        def flush_digits() -> None:
+            # Emit buffered digits in reverse to undo the reversal done in encode().
+            if digit_buffer:
+                pieces.extend(reversed(digit_buffer))
+                digit_buffer.clear()
+
+        for raw_id in token_ids:
+            token_id = int(raw_id)
+            if token_id in self._digit_token_ids:
+                flush_text()
+                digit_buffer.append(self._digit_id_to_text[token_id])
+                continue
+
+            flush_digits()
+            text_ids.append(token_id)
+
+        flush_text()
+        flush_digits()
+        return "".join(pieces)
+
+
+def build_trainer(vocab_size: int, min_frequency: int) -> Any:
+    from tokenizers import pre_tokenizers, trainers
+
+    return trainers.BpeTrainer(
+        vocab_size=vocab_size,
+        min_frequency=min_frequency,
+        special_tokens=SPECIAL_TOKENS,
+        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
+    )
+
+
+def tokenizer_files(tokenizer_dir: Path) -> tuple[Path, Path, Path]:
+    return (
+        tokenizer_dir / "tokenizer.json",
+        tokenizer_dir / "vocab.json",
+        tokenizer_dir / "merges.txt",
+    )
+
+
+def validate_tokenizer(tokenizer_dir: Path) -> None:
+    tokenizer_json, vocab_path, merges_path = tokenizer_files(tokenizer_dir)
+    if not tokenizer_json.exists():
+        raise FileNotFoundError(
+            f"Missing {tokenizer_json}. Retrain with train_tokenizer.py so the "
+            "whitespace and digit boundary rules are preserved."
+        )
+    if vocab_path.exists():
+        with vocab_path.open("r", encoding="utf-8") as f:
+            vocab = json.load(f)
+    else:
+        with tokenizer_json.open("r", encoding="utf-8") as f:
+            tokenizer_data = json.load(f)
+        vocab = tokenizer_data.get("model", {}).get("vocab")
+        if not isinstance(vocab, dict):
+            raise FileNotFoundError(f"Missing vocab.json and no embedded vocab in {tokenizer_json}")
+
+    max_id = max(vocab.values())
+    if max_id > 65_535:
+        raise ValueError(f"Tokenizer max id {max_id} does not fit in uint16")
+    if vocab.get("<|endoftext|>") != EOT_ID:
+        raise ValueError(
+            f"Expected <|endoftext|> id {EOT_ID}, "
+            f"got {vocab.get('<|endoftext|>')}"
+        )
+    missing = [
+        token
+        for token in (*[str(value) for value in range(10)], *ARITHMETIC_TOKENS)
+        if token not in vocab
+    ]
+    if missing:
+        raise ValueError(f"Tokenizer missing required atomic tokens: {missing}")
+
+
+def load_tokenizer(tokenizer_dir: Path) -> Any:
+    from tokenizers import Tokenizer
+
+    validate_tokenizer(tokenizer_dir)
+    tokenizer_json, _, _ = tokenizer_files(tokenizer_dir)
+    return FusionTokenizer(Tokenizer.from_file(str(tokenizer_json)))
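
The digit handling in `encode`, `annotate_ids`, and `decode` can be illustrated without the trained tokenizer. The following is a minimal character-level sketch, not the module's actual code path: it assumes that digit spans are maximal runs of ASCII digits (the real `_digit_span_re` is defined outside this excerpt) and uses a hypothetical value for `PLACE_OVERFLOW_ID`. The helper names `reverse_digit_spans` and `place_ids_for` are illustrative only.

```python
import re

# Assumption: a digit span is a maximal run of ASCII digits. The real
# `_digit_span_re` pattern is not shown in this excerpt.
DIGIT_SPAN_RE = re.compile(r"[0-9]+")

# Hypothetical cap; the real PLACE_OVERFLOW_ID constant is defined elsewhere.
PLACE_OVERFLOW_ID = 65


def reverse_digit_spans(text: str) -> str:
    """Reverse each run of digits so the least significant digit comes first."""
    return DIGIT_SPAN_RE.sub(lambda match: match.group(0)[::-1], text)


def place_ids_for(chars: str) -> list[int]:
    """Give each digit its 1-based offset within its run; non-digits get 0."""
    place_ids = [0] * len(chars)
    run_offset = 0
    for index, char in enumerate(chars):
        if char in "0123456789":
            run_offset += 1
            place_ids[index] = min(run_offset, PLACE_OVERFLOW_ID)
        else:
            run_offset = 0
    return place_ids


reversed_text = reverse_digit_spans("12+345=357")
print(reversed_text)                       # 21+543=753
print(place_ids_for(reversed_text))        # [1, 2, 0, 1, 2, 3, 0, 1, 2, 3]
print(reverse_digit_spans(reversed_text))  # 12+345=357
```

Because each digit run is reversed before tokenization, place ID 1 lands on the ones digit, 2 on the tens digit, and so on. Applying the reversal a second time restores the original text, which is what `decode` achieves by buffering digit tokens and emitting them in reverse order.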