N8Programs committed on
Commit 5db721a · verified · 1 Parent(s): fc79e8b

Add model card and evaluation utilities

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ radar_chart.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,179 @@
  ---
  license: cc-by-sa-4.0
+ library_name: transformers
+ pipeline_tag: text-generation
+ tags:
+ - oeis
+ - integer-sequences
+ - qwen3
+ - causal-lm
+ - mlx
+ datasets:
+ - N8Programs/oeis-massive
  ---
+
+ # NextTerm-440M
+
+ [![Benchmark radar](radar_chart.png)](radar_chart.png)
+
+ ## Model Summary
+
+ NextTerm-440M is a 440M-parameter causal transformer trained to continue integer
+ sequences. It uses a Qwen3 architecture with a compact 16-token digit
+ vocabulary: decimal digits, negative sign, comma separator, BOS, EOS, PAD, and one unused token.
+
+ The model was trained on an extended OEIS corpus in which many OEIS sequences were extended with additional terms from b-files (supplemental files provided with OEIS). The data was then further augmented with a variety of prefix-preserving transforms, selected empirically via small pilot experiments. The model was trained for 14B tokens with sequence prefixes preserved rather than distinct documents concatenated, as this was found to improve performance in pilot experiments.
+
+ NextTerm-440M improves dramatically over NextTerm-47M on long-context sequence continuation (it was trained with a context length of 4096), innate OEIS knowledge, and long-range in-context learning. The 47M model, however, remains ahead on very short prefixes that require simple rule induction without much context, which may be due to the 440M model's training on longer contexts and more complex sequences.
+
+ The tokenizer accepts integer sequences formatted as comma-separated values, for
+ example:
+
+ ```text
+ 1,-2,3,-4,
+ ```
+
+ The tokenizer ignores characters other than digits, commas, and `-`. Digits are
+ tokenized individually, so there is no fixed integer-magnitude limit, but large
+ integers consume more context. The model was not trained on numbers with leading
+ zeros, so strings like `01,02,03,` should be treated as out of distribution.
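The tokenizer behaviour described above can be illustrated with a minimal character-level sketch. This is not the tokenizer shipped with the model. The ids for digits, NEG, SEP, BOS, and EOS follow the constants in `decode_packed_oeis.py` from this commit; `PAD_ID = 14`, prepending BOS on encode, and the function names `encode` and `decode` are assumptions made for illustration.

```python
# Sketch of a character-level tokenizer for comma-separated integer sequences.
# Digit/NEG/SEP/BOS/EOS ids mirror the constants in decode_packed_oeis.py.
NEG_ID, SEP_ID, BOS_ID, EOS_ID = 10, 11, 12, 13
PAD_ID = 14  # assumption: the real PAD id is not stated in the source

CHAR_TO_ID = {str(d): d for d in range(10)}
CHAR_TO_ID["-"] = NEG_ID
CHAR_TO_ID[","] = SEP_ID
ID_TO_CHAR = {i: c for c, i in CHAR_TO_ID.items()}


def encode(text: str, add_bos: bool = True) -> list[int]:
    """Map each digit, '-' and ',' to one token id; drop every other character."""
    ids = [BOS_ID] if add_bos else []  # assumption: BOS is prepended
    ids.extend(CHAR_TO_ID[ch] for ch in text if ch in CHAR_TO_ID)
    return ids


def decode(ids: list[int]) -> str:
    """Invert `encode`, skipping special tokens such as BOS, EOS and PAD."""
    return "".join(ID_TO_CHAR[i] for i in ids if i in ID_TO_CHAR)


tokens = encode("1, -20, 300,")
print(tokens)          # spaces are dropped; every digit is its own token
print(decode(tokens))  # canonical comma-separated form
```

Because each digit costs one token, a 30-digit integer consumes 30 positions of context, which is why very large terms shorten the usable history.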
+
+ ## Training Details
+
+ | Field | Value |
+ | --- | --- |
+ | Parameters | 440,500,224 |
+ | Architecture | Qwen3-style causal LM |
+ | Layers | 28 |
+ | Hidden size | 1024 |
+ | FFN size | 3072 |
+ | Attention heads | 16 |
+ | KV heads | 8 |
+ | Vocabulary size | 16 |
+ | Training tokens | 13,999,999,995 |
+ | Sequence length cap | 4096 training tokens per sequence |
+ | Batch mode | Length-bucketed sequence batches |
+ | Optimizer | Muon/AdamW hybrid |
+ | LR schedule | Linear warmup to `1e-2` for Muon, `1e-4` for AdamW, cosine decay to 0.1x, final cooldown to 0 |
+ | Training hardware | Single H100 |
+ | Export dtype | bfloat16 |
+
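The LR schedule row in the table (linear warmup to the peak, cosine decay to 0.1x, final cooldown to 0) can be sketched as a single function. The training code is not part of this commit, so the function name `lr_at`, the `warmup_steps` and `cooldown_steps` values, and the choice of a linear cooldown are all hypothetical; only the overall shape and the peak values come from the table.

```python
import math


def lr_at(step: int, total_steps: int, peak_lr: float,
          warmup_steps: int = 1_000, cooldown_steps: int = 500) -> float:
    """Linear warmup -> cosine decay to 0.1x peak -> linear cooldown to 0.

    warmup_steps and cooldown_steps are illustrative, not the real settings.
    """
    floor_lr = 0.1 * peak_lr
    cooldown_start = total_steps - cooldown_steps
    if step < warmup_steps:
        # Ramp linearly so the last warmup step reaches the peak.
        return peak_lr * (step + 1) / warmup_steps
    if step < cooldown_start:
        # Cosine interpolation from peak_lr down to floor_lr.
        progress = (step - warmup_steps) / max(1, cooldown_start - warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return floor_lr + (peak_lr - floor_lr) * cosine
    # Final cooldown: go linearly from floor_lr to 0 at total_steps.
    remaining = (total_steps - step) / max(1, cooldown_steps)
    return floor_lr * max(0.0, remaining)


# The same shape would be used with different peaks for the two optimizers.
muon_lr = lr_at(5_000, total_steps=10_000, peak_lr=1e-2)
adamw_lr = lr_at(5_000, total_steps=10_000, peak_lr=1e-4)
print(muon_lr, adamw_lr)
```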
+ A classic Muon/AdamW hybrid was used: Muon for 2D weight matrices and AdamW for 1D parameters and embedding matrices.
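The split described above can be sketched as a routing rule over parameter shapes. This is an illustration, not the training code: the helper name `split_params`, the parameter names, and the `"embed" in name` test are assumptions, while the shapes are derived from the hidden size (1024), FFN size (3072), and vocabulary size (16) in the table.

```python
def split_params(shapes: dict[str, tuple[int, ...]]) -> tuple[list[str], list[str]]:
    """Return (muon_names, adamw_names): 2D non-embedding matrices go to Muon."""
    muon: list[str] = []
    adamw: list[str] = []
    for name, shape in shapes.items():
        is_embedding = "embed" in name  # assumed naming convention
        if len(shape) == 2 and not is_embedding:
            muon.append(name)
        else:
            adamw.append(name)  # 1D parameters (norms) and embedding matrices
    return muon, adamw


# Hypothetical parameter names with shapes taken from the table above.
example = {
    "model.embed_tokens.weight": (16, 1024),
    "model.layers.0.mlp.up_proj.weight": (3072, 1024),
    "model.layers.0.mlp.down_proj.weight": (1024, 3072),
    "model.layers.0.input_layernorm.weight": (1024,),
    "model.norm.weight": (1024,),
}
muon_names, adamw_names = split_params(example)
print("Muon: ", muon_names)
print("AdamW:", adamw_names)
```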
+
+ The model was trained on the following files in the `N8Programs/oeis-massive` dataset, randomly mixed:
+
+ - `oeis_train_bfile_prefix4096.packed`
+ - `oeis_synth_aug0_inv_len_13245370099_seed0.packed`
+
+ ## Evaluation Results
+
+ ### Main Benchmarks
+
+ | Model | OEIS-Eval-Neo | Ryskina & Knight | M1 Competition 111 MAPE |
+ | --- | ---: | ---: | ---: |
+ | **NextTerm-440M** | **34.43%** | 52.63% | **17.6239** |
+ | NextTerm-47M | 29.49% | **70.18%** | 18.7621 |
+ | Qwen3-0.6B | 18.44% | 33.33% | 22.7984 |
+ | Qwen3-1.7B | 20.77% | 49.12% | 22.2411 |
+ | Qwen3-4B | 23.74% | 63.16% | 19.1731 |
+ | Qwen3-8B | 24.62% | 57.89% | 18.4027 |
+ | Qwen3-14B | 26.00% | 59.65% | 17.9837 |
+
+ OEIS-Eval-Neo is a decontaminated held-out OEIS next-term evaluation. M1
+ Competition 111 reports macro MAPE, where lower is better. Ryskina & Knight
+ (2021) is a 57-sequence next-term benchmark based on psychometrics and puzzles. Note that the 47M model's strong performance on Ryskina & Knight is indicative of its strength on short-prefix sequences and rule induction.
87
+ ### Polynomial Continuation
88
+
89
+ The polynomial continuation evaluation samples integer sequences from
90
+ polynomials of degree 1 through 4 and asks for the next term. Accuracy is exact
91
+ match across 200 samples for each prompt length `k`.
92
+
93
+ | Model | Arithmetic | Quadratic | Cubic | Quartic |
94
+ | --- | ---: | ---: | ---: | ---: |
95
+ | **NextTerm-440M** | 94.38% | **86.39%** | **75.20%** | **67.83%** |
96
+ | NextTerm-47M | 94.15% | 81.07% | 37.43% | 15.17% |
97
+ | Qwen3-0.6B | 90.31% | 8.72% | 0.30% | 0.02% |
98
+ | Qwen3-1.7B | 93.10% | 41.57% | 5.36% | 0.71% |
99
+ | Qwen3-4B | 93.90% | 77.26% | 28.18% | 5.98% |
100
+ | Qwen3-8B | **96.10%** | 80.59% | 32.93% | 7.95% |
101
+ | Qwen3-14B | 95.60% | 84.61% | 49.16% | 14.98% |
102
+
103
+ ## Usage
104
+
105
+ ### MLX
106
+
107
+ ```bash
108
+ mlx_lm.generate --model N8Programs/NextTerm-440M --prompt "1,2,3,"
109
+ ```
110
+
111
+ ### Hugging Face Transformers
112
+
113
+ ```python
114
+ import torch
115
+ from transformers import AutoModelForCausalLM, AutoTokenizer
116
+
117
+ model_name = "N8Programs/NextTerm-440M"
118
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
119
+ model = AutoModelForCausalLM.from_pretrained(
120
+ model_name,
121
+ torch_dtype=torch.bfloat16,
122
+ device_map="auto",
123
+ )
124
+
125
+ prompt = "1,2,3,4,5,"
126
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
127
+
128
+ outputs = model.generate(
129
+ **inputs,
130
+ max_new_tokens=64,
131
+ do_sample=False,
132
+ eos_token_id=[tokenizer.convert_tokens_to_ids(","), tokenizer.eos_token_id],
133
+ pad_token_id=tokenizer.pad_token_id,
134
+ )
135
+
136
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
137
+ ```
138
+
139
+ For strict next-term evaluation, stop generation on *comma or EOS* and parse the
140
+ text before the first comma as the predicted integer.
+
+ ## Reproducibility
+
+ This repository contains the local evaluation scripts and artifacts used for the
+ results above, including:
+
+ - `oeis_eval_mlx_neo.py` for OEIS-Eval-Neo with MLX batch generation.
+ - `arithmetic_eval.py` for arithmetic/quadratic/cubic/quartic continuation.
+ - `eval_m1_competition_mape_mlx.py` for M1 Competition 111 MAPE.
+ - `eval_results.txt` for the compact result table.
+
+ The last three training checkpoints are available separately at
+ [N8Programs/NextTerm-440M-Checkpoints](https://huggingface.co/N8Programs/NextTerm-440M-Checkpoints).
+ The released `final_latest` checkpoint was trained for 14B tokens. The checkpoint with the best validation loss is also available, although it is not included in the main results table because it performed worse on downstream evaluations.
+
+ The `.packed` files used for training are binary files containing the tokenized and augmented OEIS data, with tokens encoded as nibbles. A dedicated decoder is provided in this repo as `decode_packed_oeis.py`.
+
+ ## Citation
+
+ ```bibtex
+ @misc{nextterm440m2026,
+   author = {Nathan Breslow},
+   title = {NextTerm-440M: A Pretrained Transformer for Integer Sequence Prediction},
+   year = {2026},
+   publisher = {Hugging Face},
+   howpublished = {\url{https://huggingface.co/N8Programs/NextTerm-440M}},
+   note = {440.5M parameter model trained on augmented OEIS data}
+ }
+ ```
+
+ ## Attribution
+
+ This model was trained on, and the accompanying dataset was created from, data from the
+ **On-Line Encyclopedia of Integer Sequences (OEIS)**.
+
+ - Source: https://oeis.org/
+ - License: Creative Commons Attribution-ShareAlike 4.0 (CC BY-SA 4.0)
+ - OEIS End-User License Agreement: https://oeis.org/wiki/The_OEIS_End-User_License_Agreement
arithmetic_eval.py ADDED
@@ -0,0 +1,263 @@
+ """
+ Evaluate the NextTerm model on synthetic polynomial sequences (arithmetic through
+ quartic) with varying amounts of prior context.
+
+ For each prior term count (default: 2-25 for linear, 3-25 for quadratic,
+ 4-25 for cubic, 5-25 for quartic), we sample sequences from
+ ax^4 + bx^3 + cx^2 + dx + e with coefficient ranges:
+     e in [0, 99999]
+     d in [0, 9999]
+     c in [0, 999]
+     b in [0, 99]
+     a in [0, 9]
+
+ Higher-order coefficients are set to 0 for lower-degree tests (e.g., arithmetic
+ sequences only use d and e).
+ """
+
+ import argparse
+ import csv
+ import random
+ from dataclasses import dataclass
+ from typing import Dict, List, Tuple
+
+ from mlx_lm import load
+ from mlx_lm.generate import BatchGenerator
+ from mlx_lm.tuner.utils import print_trainable_parameters
+ from tqdm import tqdm
+
+ # Step 335000
+ MODEL_NAME = "NextTerm-47M"
+ MAX_NEW_TOKENS = 20
+ DEFAULT_MAX_TERMS = 25
+ SAMPLES_PER_K = 200
+ RANDOM_SEED = 0
+
+ COEFF_MAX = {"a": 9, "b": 99, "c": 999, "d": 9999, "e": 99999}
+
+
+ @dataclass(frozen=True)
+ class PolynomialTask:
+     name: str
+     degree: int
+     min_terms: int
+
+
+ TASKS = (
+     PolynomialTask("arithmetic", 1, 2),
+     PolynomialTask("quadratic", 2, 3),
+     PolynomialTask("cubic", 3, 4),
+     PolynomialTask("quartic", 4, 5),
+ )
+
+
+ def parse_generated(text: str):
+     if "," in text:
+         text = text.split(",")[0]
+     try:
+         return int(text)
+     except ValueError:
+         print(f"Could not parse generated text: {text!r}")
+         return None
+
+
+ def polynomial_value(coeffs: Tuple[int, int, int, int, int], x: int) -> int:
+     a, b, c, d, e = coeffs
+     return (
+         a * (x**4)
+         + b * (x**3)
+         + c * (x**2)
+         + d * x
+         + e
+     )
+
+
+ def sample_coefficients(
+     degree: int, rng: random.Random
+ ) -> Tuple[int, int, int, int, int]:
+     a = rng.randint(0, COEFF_MAX["a"]) if degree >= 4 else 0
+     b = rng.randint(0, COEFF_MAX["b"]) if degree >= 3 else 0
+     c = rng.randint(0, COEFF_MAX["c"]) if degree >= 2 else 0
+     d = rng.randint(0, COEFF_MAX["d"]) if degree >= 1 else 0
+     e = rng.randint(0, COEFF_MAX["e"])
+     return a, b, c, d, e
+
+
+ def generate_polynomial_sequences(
+     task: PolynomialTask, max_terms: int, samples_per_k: int, rng: random.Random
+ ) -> List[Dict[str, int | List[int] | str]]:
+     sequences = []
+     for k in range(task.min_terms, max_terms + 1):
+         for _ in range(samples_per_k):
+             coeffs = sample_coefficients(task.degree, rng)
+             seq = [polynomial_value(coeffs, i) for i in range(k + 1)]
+             sequences.append(
+                 {"k": k, "prompt": seq[:-1], "answer": seq[-1], "task": task.name}
+             )
+     return sequences
+
+
+ def build_prompts(
+     sequences: List[Dict[str, int | List[int] | str]], tokenizer
+ ) -> Tuple[List[List[int]], List[Dict[str, int | List[int] | str]]]:
+     prompts = []
+     metadata = []
+     for record in sequences:
+         prompt_text = ",".join(str(x) for x in record["prompt"]) + ","
+         prompts.append(tokenizer.encode(prompt_text))
+         metadata.append(record)
+     return prompts, metadata
+
+
+ def evaluate_task(
+     task: PolynomialTask,
+     args,
+     model,
+     tokenizer,
+     sep_token: int,
+ ) -> Tuple[Dict[int, Dict[str, int]], List[Dict[str, object]]]:
+     rng = random.Random(args.seed + task.degree)
+     sequences = generate_polynomial_sequences(
+         task, args.max_terms, args.samples_per_k, rng
+     )
+     if not sequences:
+         print(
+             f"Skipping {task.name}: min_terms {task.min_terms} exceeds max_terms "
+             f"{args.max_terms}"
+         )
+         return {}, []
+
+     print(
+         f"\nEvaluating {task.name} sequences (degree {task.degree}) with "
+         f"{len(sequences)} samples ({args.samples_per_k} per prior-term count)"
+     )
+
+     prompts, metadata = build_prompts(sequences, tokenizer)
+
+     gen = BatchGenerator(model, stop_tokens=[[sep_token], [tokenizer.eos_token_id]])
+     uids = gen.insert(prompts, [args.max_new_tokens] * len(prompts))
+
+     meta_by_uid = dict(zip(uids, metadata))
+     results = {uid: [] for uid in uids}
+     finished = set()
+     with tqdm(total=len(uids), desc=f"{task.name.title()}") as pbar:
+         while True:
+             responses = gen.next()
+             if not isinstance(responses, tuple) or len(responses) != 2:
+                 raise RuntimeError(
+                     "Unexpected mlx_lm BatchGenerator.next() API. "
+                     "Update your mlx_lm version"
+                 )
+             prompt_responses, generation_responses = responses
+             if not prompt_responses and not generation_responses:
+                 break
+             if not generation_responses:
+                 continue
+             for r in generation_responses:
+                 results[r.uid].append(r.token)
+                 # Count a sequence as done on any finish reason (stop token or
+                 # length limit) so the progress bar reaches its total.
+                 if r.finish_reason is not None and r.uid not in finished:
+                     finished.add(r.uid)
+                     pbar.update(1)
+     gen.close()
+
+     stats = {
+         k: {"correct": 0, "parsed": 0, "total": 0}
+         for k in range(task.min_terms, args.max_terms + 1)
+     }
+
+     for uid in uids:
+         tokens = results[uid]
+         text = tokenizer.decode(tokens[:-1]) if tokens else ""
+         prediction = parse_generated(text)
+
+         meta = meta_by_uid[uid]
+         k = meta["k"]
+         answer = meta["answer"]
+
+         stats[k]["total"] += 1
+         if prediction is not None:
+             stats[k]["parsed"] += 1
+             if prediction == answer:
+                 stats[k]["correct"] += 1
+
+     print("Accuracy by prior term count:")
+     csv_rows: List[Dict[str, object]] = []
+     for k in range(task.min_terms, args.max_terms + 1):
+         s = stats[k]
+         accuracy = s["correct"] / s["total"] if s["total"] else 0.0
+         print(
+             f"{k} terms: parsed {s['parsed']}/{s['total']} "
+             f"accuracy {s['correct']}/{s['total']} = {accuracy:.4f}"
+         )
+         csv_rows.append(
+             {
+                 "task": task.name,
+                 "degree": task.degree,
+                 "k": k,
+                 "parsed": s["parsed"],
+                 "total": s["total"],
+                 "correct": s["correct"],
+                 "accuracy": accuracy,
+             }
+         )
+
+     overall_total = sum(s["total"] for s in stats.values())
+     overall_parsed = sum(s["parsed"] for s in stats.values())
+     overall_correct = sum(s["correct"] for s in stats.values())
+     overall_accuracy = overall_correct / overall_total if overall_total else 0.0
+     print(
+         f"Overall {task.name}: parsed {overall_parsed}/{overall_total} "
+         f"accuracy {overall_correct}/{overall_total} = {overall_accuracy:.4f}"
+     )
+
+     return stats, csv_rows
+
+
+ def write_csv(rows: List[Dict[str, object]], path: str) -> None:
+     if not rows:
+         print("No results to write to CSV.")
+         return
+
+     fieldnames = ["task", "degree", "k", "parsed", "total", "correct", "accuracy"]
+     with open(path, "w", newline="") as f:
+         writer = csv.DictWriter(f, fieldnames=fieldnames)
+         writer.writeheader()
+         writer.writerows(rows)
+     print(f"Wrote results to {path}")
+
+
+ def evaluate(args):
+     random.seed(args.seed)
+
+     model, tokenizer = load(args.model_name)
+
+     print_trainable_parameters(model)
+
+     sep_token = tokenizer.encode("1,")[-1]
+     all_rows: List[Dict[str, object]] = []
+     for task in TASKS:
+         _, rows = evaluate_task(task, args, model, tokenizer, sep_token)
+         all_rows.extend(rows)
+
+     if args.csv_path:
+         write_csv(all_rows, args.csv_path)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_name", default=MODEL_NAME)
+     parser.add_argument("--max_new_tokens", type=int, default=MAX_NEW_TOKENS)
+     parser.add_argument("--max_terms", type=int, default=DEFAULT_MAX_TERMS)
+     parser.add_argument("--samples_per_k", type=int, default=SAMPLES_PER_K)
+     parser.add_argument("--seed", type=int, default=RANDOM_SEED)
+     parser.add_argument("--csv_path", type=str, default=None)
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+     evaluate(args)
+
+
+ if __name__ == "__main__":
+     main()
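Each task in `arithmetic_eval.py` sets `min_terms` to `degree + 1`, which is the smallest number of terms that determines a polynomial of that degree. As a point of reference, the exact answer the model is scored against can be recovered with finite differences. The sketch below is not part of the repository and the function name `next_term_by_differences` is hypothetical; it is a classical baseline, not the model's method.

```python
def next_term_by_differences(seq: list[int], degree: int) -> int:
    """Extrapolate one term of a polynomial sequence of degree <= `degree`.

    The degree-th forward difference of such a sequence is constant, so it is
    enough to build the difference table from the last degree + 1 terms and
    sum the last entry of every row back up.
    """
    if len(seq) < degree + 1:
        raise ValueError("need at least degree + 1 terms")
    rows = [seq[-(degree + 1):]]
    for _ in range(degree):
        prev = rows[-1]
        rows.append([b - a for a, b in zip(prev, prev[1:])])
    nxt = rows[-1][-1]  # the single, constant degree-th difference
    for row in reversed(rows[:-1]):
        nxt = row[-1] + nxt
    return nxt


squares = [x * x for x in range(5)]            # 0, 1, 4, 9, 16
print(next_term_by_differences(squares, 2))    # 25
```

Unlike this baseline, the model is never told the degree, so short prompts are genuinely ambiguous for it and accuracy depends on the prompt length `k`.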
decode_packed_oeis.py ADDED
@@ -0,0 +1,185 @@
+ """Small reference decoder for bigOEIS `.packed` files.
+
+ On disk, each token is stored in a 4-bit nibble:
+
+     0..9  decimal digits
+     10    term separator, i.e. comma
+     11    negative sign
+     14    final padding nibble, if the file has an odd nibble count
+     15    sequence delimiter / EOS
+
+ Note that the packed disk codes are intentionally compact and are not exactly
+ the model vocabulary ids: the model uses NEG=10 and SEP=11. Use
+ `iter_model_token_rows()` when you want rows in model-token-id space.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ from pathlib import Path
+ from typing import Iterator
+
+
+ PACKED_SEP = 10
+ PACKED_NEG = 11
+ PACKED_PAD = 14
+ PACKED_DELIM = 15
+
+ MODEL_NEG = 10
+ MODEL_SEP = 11
+ MODEL_BOS = 12
+ MODEL_EOS = 13
+
+
+ def iter_packed_nibbles(path: str | Path, chunk_size: int = 8 * 1024 * 1024) -> Iterator[int]:
+     """Yield high nibble then low nibble for every byte in `path`."""
+     with Path(path).open("rb") as f:
+         while True:
+             chunk = f.read(chunk_size)
+             if not chunk:
+                 return
+             for byte in chunk:
+                 yield byte >> 4
+                 yield byte & 0x0F
+
+
+ def iter_model_token_rows(
+     path: str | Path,
+     *,
+     include_bos_eos: bool = True,
+     max_rows: int | None = None,
+     strict: bool = True,
+ ) -> Iterator[list[int]]:
+     """Yield each packed sequence as model-token ids.
+
+     By default rows include BOS/EOS, matching `tokenize_utils.tokenize_sequence`.
+     Set `include_bos_eos=False` to get only the content tokens.
+     """
+     row: list[int] = [MODEL_BOS] if include_bos_eos else []
+     yielded = 0
+     seen_pad = False
+
+     for nib in iter_packed_nibbles(path):
+         if nib == PACKED_PAD:
+             seen_pad = True
+             continue
+         if seen_pad:
+             if strict:
+                 raise ValueError("Found non-pad nibble after final packed padding.")
+             continue
+
+         if 0 <= nib <= 9:
+             row.append(nib)
+         elif nib == PACKED_SEP:
+             row.append(MODEL_SEP)
+         elif nib == PACKED_NEG:
+             row.append(MODEL_NEG)
+         elif nib == PACKED_DELIM:
+             if include_bos_eos:
+                 row.append(MODEL_EOS)
+             yield row
+             yielded += 1
+             if max_rows is not None and yielded >= max_rows:
+                 return
+             row = [MODEL_BOS] if include_bos_eos else []
+         else:
+             if strict:
+                 raise ValueError(f"Invalid packed nibble: {nib}")
+
+     empty = [MODEL_BOS] if include_bos_eos else []
+     if strict and row != empty:
+         raise ValueError("Packed file ended with an unterminated sequence.")
+
+
+ def iter_integer_sequences(
+     path: str | Path,
+     *,
+     as_ints: bool = False,
+     max_rows: int | None = None,
+     strict: bool = True,
+ ) -> Iterator[list[int] | list[str]]:
+     """Yield decoded OEIS rows.
+
+     Values are strings by default so enormous integers round-trip exactly
+     through JSON. Pass `as_ints=True` if Python integers are more convenient.
+     """
+     terms: list[str] = []
+     chars: list[str] = []
+     yielded = 0
+     seen_pad = False
+
+     def finish_term() -> None:
+         if chars:
+             terms.append("".join(chars))
+             chars.clear()
+         elif strict:
+             raise ValueError("Encountered an empty term in packed sequence.")
+
+     for nib in iter_packed_nibbles(path):
+         if nib == PACKED_PAD:
+             seen_pad = True
+             continue
+         if seen_pad:
+             if strict:
+                 raise ValueError("Found non-pad nibble after final packed padding.")
+             continue
+
+         if 0 <= nib <= 9:
+             chars.append(str(nib))
+         elif nib == PACKED_NEG:
+             if strict and chars:
+                 raise ValueError("Found a negative sign after term digits had started.")
+             chars.append("-")
+         elif nib == PACKED_SEP:
+             finish_term()
+         elif nib == PACKED_DELIM:
+             if chars:
+                 finish_term()
+             row = [int(term) for term in terms] if as_ints else list(terms)
+             yield row
+             yielded += 1
+             if max_rows is not None and yielded >= max_rows:
+                 return
+             terms.clear()
+         else:
+             if strict:
+                 raise ValueError(f"Invalid packed nibble: {nib}")
+
+     if strict and (terms or chars):
+         raise ValueError("Packed file ended with an unterminated sequence.")
+
+
+ def _main() -> None:
+     parser = argparse.ArgumentParser(description=__doc__)
+     parser.add_argument("packed_file", help="Path to a .packed file")
+     parser.add_argument("-n", "--max-rows", type=int, default=5)
+     parser.add_argument("--tokens", action="store_true", help="Print model-token rows instead of decoded terms")
+     parser.add_argument("--content-only", action="store_true", help="Omit BOS/EOS when printing token rows")
+     parser.add_argument("--ints", action="store_true", help="Emit decoded terms as JSON numbers instead of strings")
+     parser.add_argument("--no-strict", action="store_true", help="Ignore invalid/trailing data instead of raising")
+     args = parser.parse_args()
+
+     strict = not args.no_strict
+     if args.tokens:
+         rows = iter_model_token_rows(
+             args.packed_file,
+             include_bos_eos=not args.content_only,
+             max_rows=args.max_rows,
+             strict=strict,
+         )
+         for row in rows:
+             print(json.dumps({"tokens": row}, separators=(",", ":")))
+     else:
+         rows = iter_integer_sequences(
+             args.packed_file,
+             as_ints=args.ints,
+             max_rows=args.max_rows,
+             strict=strict,
+         )
+         for row in rows:
+             print(json.dumps({"seq": row}, separators=(",", ":")))
+
+
+ if __name__ == "__main__":
+     _main()
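The nibble layout in the decoder's docstring can also be read in the other direction. The sketch below is a hypothetical encoder, `pack_sequences`, written only to illustrate that layout; it is not the tool that produced the dataset's `.packed` files. In particular, placing the separator between terms rather than after every term is an assumption, although the decoder above accepts either form.

```python
PACKED_SEP, PACKED_NEG, PACKED_PAD, PACKED_DELIM = 10, 11, 14, 15


def pack_sequences(rows: list[list[int]]) -> bytes:
    """Encode integer rows as 4-bit codes, two per byte, high nibble first."""
    nibbles: list[int] = []
    for row in rows:
        for i, value in enumerate(row):
            if i:
                nibbles.append(PACKED_SEP)  # assumption: separator between terms only
            if value < 0:
                nibbles.append(PACKED_NEG)
            nibbles.extend(int(ch) for ch in str(abs(value)))
        nibbles.append(PACKED_DELIM)  # closes the sequence
    if len(nibbles) % 2:
        nibbles.append(PACKED_PAD)  # fill the low nibble of the last byte
    return bytes((hi << 4) | lo for hi, lo in zip(nibbles[::2], nibbles[1::2]))


packed = pack_sequences([[1, -2, 30]])
print(packed.hex())
```

At one nibble per token, a packed file takes half a byte per training token, which is what makes this format compact for a 16-symbol vocabulary.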
eval_m1_competition_mape_mlx.py ADDED
@@ -0,0 +1,425 @@
+ #!/usr/bin/env python3
+ """Evaluate NextTerm-style MLX models on the canonical 111-series M1 set."""
+
+ from __future__ import annotations
+
+ import argparse
+ import gc
+ import json
+ import math
+ import time
+ from dataclasses import dataclass
+ from decimal import Decimal
+ from pathlib import Path
+ from statistics import mean, median
+
+ import mlx.core as mx
+ from mlx_lm import load
+ from mlx_lm.generate import BatchGenerator
+ from tqdm import tqdm
+
+ from eval_m1_monthly_mape_mlx import (
+     SuppressTokenLogits,
+     context_scale,
+     mape_for_predictions,
+     parse_decimal,
+     parse_generated_terms,
+     scale_decimal_to_int,
+ )
+
+
+ DATA_PATH = Path("m1_competition_111.jsonl")
+ MODEL_PATH = Path("NextTerm-440M")
+ OUTPUT_PATH = Path("m1_eval_results/m1_competition111_nextterm440m_greedy_per_series.jsonl")
+ SUMMARY_PATH = Path("m1_eval_results/m1_competition111_nextterm440m_greedy_summary.json")
+
+
+ @dataclass
+ class SeriesRecord:
+     row_index: int
+     series_name: str
+     original_id: str
+     period: str
+     frequency: int
+     type: str
+     description: str
+     horizon: int
+     raw_context: list[str]
+     raw_target: list[str]
+     context_values: list[Decimal]
+     target_values: list[Decimal]
+     naive2_forecast: list[Decimal]
+     naive2_mape: float
+     scale: int
+     scaled_context: list[int]
+
+
+ def load_m1_competition(path: Path) -> list[SeriesRecord]:
+     records: list[SeriesRecord] = []
+     with path.open("r", encoding="utf-8") as f:
+         for line in f:
+             if not line.strip():
+                 continue
+             row = json.loads(line)
+             raw_context = [str(x) for x in row["context"]]
+             raw_target = [str(x) for x in row["target"]]
+             context_values = [parse_decimal(x) for x in raw_context]
+             target_values = [parse_decimal(x) for x in raw_target]
+             scale = context_scale(raw_context)
+             scaled_context = [scale_decimal_to_int(v, scale) for v in context_values]
+             records.append(
+                 SeriesRecord(
+                     row_index=int(row["row_index"]),
+                     series_name=str(row["series_name"]),
+                     original_id=str(row["original_id"]),
+                     period=str(row["period"]),
+                     frequency=int(row["frequency"]),
+                     type=str(row["type"]),
+                     description=str(row["description"]),
+                     horizon=int(row["horizon"]),
+                     raw_context=raw_context,
+                     raw_target=raw_target,
+                     context_values=context_values,
+                     target_values=target_values,
+                     naive2_forecast=[parse_decimal(str(x)) for x in row["naive2_forecast"]],
+                     naive2_mape=float(row["naive2_mape"]),
+                     scale=scale,
+                     scaled_context=scaled_context,
+                 )
+             )
+     return records
+
+
+ def decimal_or_none(value: Decimal | None) -> str | None:
+     return str(value) if value is not None else None
+
+
+ def load_completed(path: Path) -> dict[int, dict]:
+     completed: dict[int, dict] = {}
+     if not path.exists():
+         return completed
+     with path.open("r", encoding="utf-8") as f:
+         for line in f:
+             if not line.strip():
+                 continue
+             record = json.loads(line)
+             completed[int(record["row_index"])] = record
+     return completed
+
+
+ def aggregate_summary(records: list[dict]) -> dict:
+     def collect(key: str, rows: list[dict] = records) -> list[float]:
+         vals: list[float] = []
+         for r in rows:
+             value = r.get(key)
+             if value is not None and math.isfinite(float(value)):
+                 vals.append(float(value))
+         return vals
+
+     model_series = collect("mape")
+     naive2_series = collect("naive2_mape")
+     all_model_apes = [
+         float(x)
+         for r in records
+         for x in (r.get("apes") or [])
+         if x is not None and math.isfinite(float(x))
+     ]
+     all_naive2_apes = [
+         float(x)
+         for r in records
+         for x in (r.get("naive2_apes") or [])
+         if x is not None and math.isfinite(float(x))
+     ]
+
+     periods = sorted({str(r["period"]) for r in records})
+     by_period = {}
+     for period in periods:
+         rows = [r for r in records if r["period"] == period]
+         period_model = collect("mape", rows)
+         period_naive2 = collect("naive2_mape", rows)
+         period_model_apes = [
+             float(x)
+             for r in rows
+             for x in (r.get("apes") or [])
+             if x is not None and math.isfinite(float(x))
+         ]
+         period_naive2_apes = [
+             float(x)
+             for r in rows
+             for x in (r.get("naive2_apes") or [])
+             if x is not None and math.isfinite(float(x))
+         ]
+         by_period[period] = {
+             "series_count": len(rows),
+             "model_macro_mape": mean(period_model) if period_model else None,
+             "model_point_mape": mean(period_model_apes) if period_model_apes else None,
+             "naive2_macro_mape": mean(period_naive2) if period_naive2 else None,
+             "naive2_point_mape": mean(period_naive2_apes) if period_naive2_apes else None,
+             "parsed_full_series": sum(
+                 1 for r in rows if int(r.get("parsed_terms", 0)) >= int(r.get("horizon", 0))
+             ),
+         }
+
+     horizon_counts: dict[int, int] = {}
+     for r in records:
+         h = int(r["horizon"])
+         horizon_counts[h] = horizon_counts.get(h, 0) + 1
+
+     return {
+         "series_count": len(records),
+         "horizon_counts": dict(sorted(horizon_counts.items())),
+         "model_macro_mape": mean(model_series) if model_series else None,
+         "model_median_series_mape": median(model_series) if model_series else None,
+         "model_point_mape": mean(all_model_apes) if all_model_apes else None,
+         "naive2_macro_mape": mean(naive2_series) if naive2_series else None,
+         "naive2_point_mape": mean(all_naive2_apes) if all_naive2_apes else None,
+         "parsed_full_series": sum(
+             1 for r in records if int(r.get("parsed_terms", 0)) >= int(r.get("horizon", 0))
+         ),
+         "total_missing_terms": sum(int(r.get("missing_terms", 0)) for r in records),
+         "malformed_series": sum(1 for r in records if r.get("malformed", False)),
+         "by_period": by_period,
+     }
+
+
185
+ def parse_args() -> argparse.Namespace:
186
+ parser = argparse.ArgumentParser()
187
+ parser.add_argument("--data-path", type=Path, default=DATA_PATH)
188
+ parser.add_argument("--model", type=Path, default=MODEL_PATH)
189
+ parser.add_argument("--output", type=Path, default=OUTPUT_PATH)
190
+ parser.add_argument("--summary-output", type=Path, default=SUMMARY_PATH)
191
+ parser.add_argument("--batch-size", type=int, default=64)
192
+ parser.add_argument("--max-new-tokens", type=int, default=1024)
193
+ parser.add_argument("--max-context-tokens", type=int, default=2048)
194
+ parser.add_argument("--max-series", type=int, default=0)
195
+ parser.add_argument(
196
+ "--allow-eos-before-horizon",
197
+ action="store_true",
198
+ help="Do not suppress EOS/pad while waiting for all forecast terms.",
199
+ )
200
+ parser.add_argument(
201
+ "--missing-policy",
202
+ choices=["zero", "last_context", "skip"],
203
+ default="zero",
204
+ )
205
+ parser.add_argument("--overwrite", action="store_true")
206
+ return parser.parse_args()
207
+
208
+
+ def main() -> None:
+     args = parse_args()
+     started = time.perf_counter()
+     series = load_m1_competition(args.data_path)
+     if args.max_series > 0:
+         series = series[: args.max_series]
+     print(f"Loaded {len(series)} canonical M1 competition series from {args.data_path}")
+     print("Horizons:", {h: sum(1 for s in series if s.horizon == h) for h in sorted({s.horizon for s in series})})
+
+     model, tokenizer = load(str(args.model))
+     print(f"Loaded model: {args.model}")
+
+     prompts_text = [",".join(str(x) for x in s.scaled_context) + "," for s in series]
+     prompts = [tokenizer.encode(text) for text in prompts_text]
+     sep_token = tokenizer.encode("1,")[-1]
+     eval_indices = [
+         i
+         for i, prompt in enumerate(prompts)
+         if args.max_context_tokens <= 0 or len(prompt) < args.max_context_tokens
+     ]
+     skipped_long = len(prompts) - len(eval_indices)
+     eval_indices = sorted(eval_indices, key=lambda i: len(prompts[i]))
+     print(
+         f"Evaluating {len(eval_indices)} series; skipped_long={skipped_long}; "
+         f"batch_size={args.batch_size}; max_new_tokens={args.max_new_tokens}; sep_token={sep_token}"
+     )
+
+     args.output.parent.mkdir(parents=True, exist_ok=True)
+     args.summary_output.parent.mkdir(parents=True, exist_ok=True)
+     if args.overwrite and args.output.exists():
+         args.output.unlink()
+     completed = load_completed(args.output)
+     if completed:
+         print(f"Resuming from {args.output}: {len(completed)} series already done")
+     todo_indices = [i for i in eval_indices if i not in completed]
+
+     stop_tokens = []
+     suppress_token_ids: list[int] = []
+     if getattr(tokenizer, "eos_token_id", None) is not None:
+         suppress_token_ids.append(tokenizer.eos_token_id)
+         if args.allow_eos_before_horizon:
+             stop_tokens.append([tokenizer.eos_token_id])
+     if getattr(tokenizer, "pad_token_id", None) is not None:
+         suppress_token_ids.append(tokenizer.pad_token_id)
+         if args.allow_eos_before_horizon:
+             stop_tokens.append([tokenizer.pad_token_id])
+     suppress_processor = None if args.allow_eos_before_horizon else SuppressTokenLogits(suppress_token_ids)
+     gen = BatchGenerator(
+         model,
+         stop_tokens=stop_tokens or None,
+         completion_batch_size=args.batch_size,
+         prefill_batch_size=args.batch_size,
+     )
+     uid_to_idx: dict[int, int] = {}
+     generated_tokens: dict[int, list[int]] = {}
+     separator_counts: dict[int, int] = {}
+     if todo_indices:
+         logits_processors = (
+             [[suppress_processor] for _ in todo_indices]
+             if suppress_processor is not None
+             else None
+         )
+         uids = gen.insert(
+             [prompts[i] for i in todo_indices],
+             [args.max_new_tokens] * len(todo_indices),
+             logits_processors=logits_processors,
+         )
+         uid_to_idx = {uid: idx for uid, idx in zip(uids, todo_indices)}
+         generated_tokens = {uid: [] for uid in uids}
+         separator_counts = {uid: 0 for uid in uids}
+     finished: set[int] = set()
+
+     try:
+         with args.output.open("a", encoding="utf-8") as out:
+             with tqdm(total=len(todo_indices), desc="Generating") as progress:
+
+                 def finalize_uid(uid: int, finish_reason: str) -> None:
+                     if uid in finished:
+                         return
+                     finished.add(uid)
+                     idx = uid_to_idx[uid]
+                     s = series[idx]
+                     generated_text = tokenizer.decode(generated_tokens[uid])
+                     scaled_terms, malformed = parse_generated_terms(generated_text, s.horizon)
+                     predictions: list[Decimal | None] = [
+                         Decimal(term) / Decimal(s.scale) for term in scaled_terms
+                     ]
+                     predictions.extend([None] * (s.horizon - len(predictions)))
+                     fallback = s.context_values[-1] if s.context_values else Decimal(0)
+                     mape, apes, missing = mape_for_predictions(
+                         s.target_values,
+                         predictions,
+                         missing_policy=args.missing_policy,
+                         fallback=fallback,
+                     )
+                     naive2_mape, naive2_apes, _ = mape_for_predictions(
+                         s.target_values,
+                         s.naive2_forecast,
+                         missing_policy="skip",
+                         fallback=fallback,
+                     )
+                     record = {
+                         "row_index": s.row_index,
+                         "series_name": s.series_name,
+                         "original_id": s.original_id,
+                         "period": s.period,
+                         "frequency": s.frequency,
+                         "type": s.type,
+                         "description": s.description,
+                         "horizon": s.horizon,
+                         "scale": s.scale,
+                         "context_length": len(s.context_values),
+                         "prompt_tokens": len(prompts[idx]),
+                         "target": [str(x) for x in s.target_values],
+                         "scaled_prediction_terms": scaled_terms,
+                         "prediction": [decimal_or_none(x) for x in predictions],
+                         "parsed_terms": len(scaled_terms),
+                         "missing_terms": missing,
+                         "malformed": malformed,
+                         "mape": mape,
+                         "apes": apes,
+                         "naive2_forecast": [str(x) for x in s.naive2_forecast],
+                         "naive2_mape": naive2_mape,
+                         "naive2_apes": naive2_apes,
+                         "generated_text": generated_text,
+                         "finish_reason": finish_reason,
+                     }
+                     out.write(json.dumps(record) + "\n")
+                     out.flush()
+                     progress.update(1)
+
+                 while todo_indices and len(finished) < len(todo_indices):
+                     responses = gen.next()
+                     if isinstance(responses, tuple) and len(responses) == 2:
+                         prompt_responses, generation_responses = responses
+                     elif isinstance(responses, list):
+                         prompt_responses, generation_responses = [], responses
+                     else:
+                         raise RuntimeError(
+                             "Unexpected mlx_lm BatchGenerator.next() API. Update mlx_lm version."
+                         )
+                     if not prompt_responses and not generation_responses:
+                         break
+                     remove_uids: list[int] = []
+                     for response in generation_responses:
+                         uid = response.uid
+                         if uid in finished:
+                             continue
+                         idx = uid_to_idx[uid]
+                         horizon = series[idx].horizon
+                         if response.finish_reason != "stop":
+                             token = int(response.token)
+                             generated_tokens[uid].append(token)
+                             if token == sep_token:
+                                 separator_counts[uid] += 1
+                                 if separator_counts[uid] >= horizon:
+                                     finalize_uid(uid, "separator_count")
+                                     remove_uids.append(uid)
+                                     continue
+                         if response.finish_reason is not None:
+                             finalize_uid(uid, str(response.finish_reason))
+                     if remove_uids:
+                         gen.remove(remove_uids)
+         if len(finished) != len(todo_indices):
+             raise RuntimeError(f"Finished {len(finished)}/{len(todo_indices)} series")
+     finally:
+         gen.close()
+         mx.clear_cache()
+         gc.collect()
+
+     eval_index_set = set(eval_indices)
+     all_records = [
+         r for idx, r in sorted(load_completed(args.output).items()) if idx in eval_index_set
+     ]
+     aggregate = aggregate_summary(all_records)
+     elapsed = time.perf_counter() - started
+     summary = {
+         "data_path": str(args.data_path),
+         "model": str(args.model),
+         "output": str(args.output),
+         "series_loaded": len(series),
+         "series_evaluated": len(all_records),
+         "skipped_long": skipped_long,
+         "max_context_tokens": args.max_context_tokens,
+         "max_new_tokens": args.max_new_tokens,
+         "batch_size": args.batch_size,
+         "missing_policy": args.missing_policy,
+         "suppress_eos_until_horizon": not args.allow_eos_before_horizon,
+         "seconds": elapsed,
+         **aggregate,
+     }
+     args.summary_output.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8")
+     print(
+         json.dumps(
+             {
+                 k: summary[k]
+                 for k in [
+                     "series_evaluated",
+                     "model_macro_mape",
+                     "model_point_mape",
+                     "naive2_macro_mape",
+                     "naive2_point_mape",
+                     "parsed_full_series",
+                     "total_missing_terms",
+                     "malformed_series",
+                     "seconds",
+                 ]
+             },
+             indent=2,
+         )
+     )
+     print(f"Wrote {args.output}")
+     print(f"Wrote {args.summary_output}")
+
+
+ if __name__ == "__main__":
+     main()
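The stopping rule in the generation loop above (count separator tokens and finalize a series once the count reaches its forecast horizon) can be illustrated in isolation. The following is a minimal sketch, not the script's own code: `take_until_horizon` and the toy token ids (digits `0`-`9` map to themselves, comma is `10`) are hypothetical names chosen for illustration, and no MLX objects are involved.

```python
from typing import Iterable


def take_until_horizon(
    token_stream: Iterable[int], sep_token: int, horizon: int
) -> tuple[list[int], str]:
    """Consume tokens until `horizon` separators have been seen.

    Because the prompt already ends with a separator, each generated term is
    terminated by exactly one separator, so `horizon` separators correspond to
    `horizon` completed terms.
    """
    kept: list[int] = []
    separators = 0
    for token in token_stream:
        kept.append(token)
        if token == sep_token:
            separators += 1
            if separators >= horizon:
                # Mirrors the "separator_count" finish reason used above.
                return kept, "separator_count"
    # The stream ended (for example, the token budget ran out) before
    # enough terms were produced.
    return kept, "exhausted"


# Hypothetical toy vocabulary: digit d -> id d, comma -> id 10.
COMMA = 10
stream = [1, 2, COMMA, 7, COMMA, 3, 0, 5, COMMA, 9]  # "12,7,305,9"
kept, reason = take_until_horizon(stream, sep_token=COMMA, horizon=2)
print(kept, reason)
```

Stopping on a separator count rather than on EOS keeps every series at exactly its required number of forecast terms, which is presumably why the script suppresses EOS/PAD by default.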
eval_results.txt ADDED
@@ -0,0 +1,42 @@
+ OEIS-Eval-Neo:
+ NextTerm-440M - 34.43%
+ NextTerm-47M - 29.49%
+ Qwen3-0.6B - 18.44%
+ Qwen3-1.7B - 20.77%
+ Qwen3-4B - 23.74%
+ Qwen3-8B - 24.62%
+ Qwen3-14B - 26.00%
+
+ Ryskina & Knight (2021):
+ NextTerm-440M - 52.63%
+ NextTerm-47M - 70.18%
+
+ 16-Shot Bitstring:
+ NextTerm-440M - 32.00%
+ NextTerm-47M - 27.88%
+
+ M1 Competition 111 MAPE (macro; canonical greedy; lower is better):
+ Naive2 - 17.7987
+ NextTerm-440M - 17.6239
+ NextTerm-47M - 18.7621
+ Qwen3-0.6B - 22.7984
+ Qwen3-1.7B - 22.2411
+ Qwen3-4B - 19.1731
+ Qwen3-8B - 18.4027
+ Qwen3-14B - 17.9837
+
+ M1 Competition 111 by frequency (macro MAPE):
+ Naive2 - monthly 17.3871 / quarterly 19.5958 / yearly 17.1314
+ NextTerm-440M - monthly 18.3407 / quarterly 19.0475 / yearly 13.5498
+ NextTerm-47M - monthly 21.2719 / quarterly 16.0270 / yearly 13.3741
+ Qwen3-0.6B - monthly 24.7585 / quarterly 21.7989 / yearly 17.2835
+ Qwen3-1.7B - monthly 24.3821 / quarterly 22.5869 / yearly 14.5642
+ Qwen3-4B - monthly 20.6455 / quarterly 19.6394 / yearly 13.6308
+ Qwen3-8B - monthly 20.0289 / quarterly 17.7249 / yearly 13.6534
+ Qwen3-14B - monthly 19.4006 / quarterly 18.1729 / yearly 12.9486
+
+ Polynomial continuation evals (accuracy; 200 samples per k):
+ Arithmetic (k=2-25): NextTerm-440M - 94.38%; NextTerm-47M - 94.15%
+ Quadratic (k=3-25): NextTerm-440M - 86.39%; NextTerm-47M - 81.07%
+ Cubic (k=4-25): NextTerm-440M - 75.20%; NextTerm-47M - 37.43%
+ Quartic (k=5-25): NextTerm-440M - 67.83%; NextTerm-47M - 15.17%
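Since `eval_results.txt` is plain text in a `Name - value` layout, a small parser is enough to rank the models or compare them against a baseline. The sketch below is one possible way to do that and is not part of the commit: `parse_scores` and `LINE_RE` are hypothetical names, and the embedded block simply copies the macro-MAPE figures from the M1 section of the file.

```python
import re

# Macro MAPE values copied from the "M1 Competition 111 MAPE" block above.
M1_BLOCK = """\
Naive2 - 17.7987
NextTerm-440M - 17.6239
NextTerm-47M - 18.7621
Qwen3-0.6B - 22.7984
Qwen3-1.7B - 22.2411
Qwen3-4B - 19.1731
Qwen3-8B - 18.4027
Qwen3-14B - 17.9837
"""

# Matches "<name> - <number>" with an optional leading "+" and trailing "%".
LINE_RE = re.compile(
    r"^\s*\+?\s*(?P<name>\S+)\s+-\s+(?P<value>-?\d+(?:\.\d+)?)%?\s*$"
)


def parse_scores(block: str) -> dict[str, float]:
    """Return {model name: score} for every line that fits the layout."""
    scores: dict[str, float] = {}
    for line in block.splitlines():
        match = LINE_RE.match(line)
        if match:
            scores[match.group("name")] = float(match.group("value"))
    return scores


scores = parse_scores(M1_BLOCK)
baseline = scores["Naive2"]
# Lower MAPE is better, so sort ascending.
ranked = sorted(scores.items(), key=lambda item: item[1])
for name, value in ranked:
    relative = (value - baseline) / baseline * 100
    print(f"{name:<14} {value:8.4f}  {relative:+6.2f}% vs Naive2")
```

The regular expression deliberately ignores header lines and the multi-value "by frequency" rows, so those would need their own pattern.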
oeis_eval_mlx_neo.py ADDED
@@ -0,0 +1,310 @@
+ """Evaluate NextTerm on OEIS Eval Neo with MLX-LM BatchGenerator."""
+
+ import argparse
+ import gc
+ import inspect
+ import json
+ import time
+ from pathlib import Path
+
+ import mlx.core as mx
+ from mlx_lm import load
+ from mlx_lm.generate import BatchGenerator
+ from tqdm import tqdm
+
+ DATA_PATH = Path("/Users/natebreslow/Documents/khashabiLab/bigOEIS/oeis_val_neo.jsonl")
+ MODEL_NAME = "/Users/natebreslow/Documents/khashabiLab/bigOEIS/NextTerm-440M"
+ MAX_NEW_TOKENS = 196
+ MAX_CONTEXT_TOKENS = 4096
+ BATCH_SIZE = 64
+ OUTPUT_PATH = Path("oeis_eval_results/oeis_eval_mlx_neo_per_doc.jsonl")
+ SUMMARY_PATH = Path("oeis_eval_results/oeis_eval_mlx_neo_summary.json")
+ PARSE_ERROR_PRINT_LIMIT = 25
+ parse_error_print_count = 0
+
+
+ def parse_generated(text: str) -> int | None:
+     global parse_error_print_count
+     if "," in text:
+         text = text.split(",")[0]
+     try:
+         return int(text)
+     except ValueError:
+         if parse_error_print_count < PARSE_ERROR_PRINT_LIMIT:
+             print(f"Could not parse generated text: {text!r}")
+             parse_error_print_count += 1
+         return None
+
+
+ def load_sequences(path: Path):
+     sequences = []
+     answers = []
+     with path.open() as f:
+         for line in f:
+             record = json.loads(line, parse_int=str)
+             seq = record.get("seq", [])
+             if len(seq) < 2:
+                 continue
+             sequences.append(seq[:-1])
+             answers.append(str(seq[-1]))
+     return sequences, answers
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--data-path", type=Path, default=DATA_PATH)
+     parser.add_argument("--model", default=MODEL_NAME)
+     parser.add_argument("--max-new-tokens", type=int, default=MAX_NEW_TOKENS)
+     parser.add_argument("--max-context-tokens", type=int, default=MAX_CONTEXT_TOKENS)
+     parser.add_argument("--batch-size", type=int, default=BATCH_SIZE)
+     parser.add_argument("--max-examples", type=int, default=0)
+     parser.add_argument("--output", type=Path, default=OUTPUT_PATH)
+     parser.add_argument("--summary-output", type=Path, default=SUMMARY_PATH)
+     parser.add_argument("--overwrite", action="store_true")
+     parser.add_argument("--restrict-digit-comma-eos", action="store_true")
+     parser.add_argument("--restrict-integer-comma-eos", action="store_true")
+     return parser.parse_args()
+
+
+ def load_completed(path: Path) -> dict[int, dict]:
+     completed = {}
+     if not path.exists():
+         return completed
+     with path.open("r", encoding="utf-8") as f:
+         for line in f:
+             if not line.strip():
+                 continue
+             record = json.loads(line)
+             completed[int(record["row_index"])] = record
+     return completed
+
+
+ def encode_no_special(tokenizer, text: str) -> list[int]:
+     try:
+         return tokenizer.encode(text, add_special_tokens=False)
+     except TypeError:
+         return tokenizer.encode(text)
+
+
+ def normalize_stop_tokens_for_batch_generator(stop_tokens: list[list[int]]):
+     annotation = inspect.signature(BatchGenerator.__init__).parameters[
+         "stop_tokens"
+     ].annotation
+     if "set" in str(annotation):
+         return {seq[0] for seq in stop_tokens if len(seq) == 1}
+     return stop_tokens
+
+
+ def split_batch_generator_responses(responses):
+     if isinstance(responses, tuple) and len(responses) == 2:
+         prompt_responses, generation_responses = responses
+         return prompt_responses, generation_responses
+     if isinstance(responses, list):
+         return [], responses
+     raise RuntimeError(
+         "Unexpected mlx_lm BatchGenerator.next() API. Update this script for "
+         f"{type(responses).__name__}: {responses!r}"
+     )
+
+
+ def make_integer_comma_eos_processor(tokenizer):
+     allowed = set()
+     for text in [str(i) for i in range(10)] + ["-", ","]:
+         tokens = encode_no_special(tokenizer, text)
+         if len(tokens) == 1:
+             allowed.add(int(tokens[0]))
+         else:
+             print(f"Skipping multi-token allowed text {text!r}: {tokens}")
+     if tokenizer.eos_token_id is not None:
+         allowed.add(int(tokenizer.eos_token_id))
+     allowed_ids = sorted(allowed)
+     mask_cache = {}
+
+     def processor(_tokens, logits):
+         vocab_size = logits.shape[-1]
+         mask = mask_cache.get(vocab_size)
+         if mask is None:
+             values = [-1e9] * vocab_size
+             for token_id in allowed_ids:
+                 if 0 <= token_id < vocab_size:
+                     values[token_id] = 0.0
+             mask = mx.array(values, dtype=logits.dtype)
+             mask_cache[vocab_size] = mask
+         return logits + mask[None, :]
+
+     return processor, allowed_ids
+
+
+ def run_generation_queue(
+     *,
+     model,
+     tokenizer,
+     prompts,
+     answers,
+     row_indices,
+     stop_tokens,
+     max_new_tokens: int,
+     batch_size: int,
+     output_file,
+     progress,
+     logits_processors=None,
+ ) -> None:
+     gen = BatchGenerator(
+         model,
+         stop_tokens=normalize_stop_tokens_for_batch_generator(stop_tokens),
+         logits_processors=logits_processors,
+         completion_batch_size=batch_size,
+         prefill_batch_size=batch_size,
+     )
+     uids = gen.insert(prompts, [max_new_tokens] * len(prompts))
+     uid_to_pos = {uid: pos for pos, uid in enumerate(uids)}
+     generated_tokens = {uid: [] for uid in uids}
+     finished = set()
+     try:
+         while True:
+             responses = gen.next()
+             prompt_responses, generation_responses = split_batch_generator_responses(
+                 responses
+             )
+             if not prompt_responses and not generation_responses:
+                 break
+             if not generation_responses:
+                 continue
+             for response in generation_responses:
+                 uid = response.uid
+                 if response.finish_reason != "stop":
+                     generated_tokens[uid].append(int(response.token))
+                 if response.finish_reason is None or uid in finished:
+                     continue
+                 finished.add(uid)
+                 pos = uid_to_pos[uid]
+                 text = tokenizer.decode(generated_tokens[uid])
+                 prediction = parse_generated(text)
+                 answer = answers[pos]
+                 answer_int = int(answer)
+                 record = {
+                     "row_index": row_indices[pos],
+                     "answer": answer,
+                     "prediction": prediction,
+                     "correct": prediction == answer_int,
+                     "parsed": prediction is not None,
+                     "generated_text": text,
+                     "generated_tokens": generated_tokens[uid],
+                     "finish_reason": response.finish_reason,
+                 }
+                 output_file.write(json.dumps(record) + "\n")
+                 progress.update(1)
+     finally:
+         gen.close()
+         mx.clear_cache()
+         gc.collect()
+
+     if len(finished) != len(uids):
+         raise RuntimeError(f"Chunk finished {len(finished)}/{len(uids)} rows")
+
+
+ def main():
+     args = parse_args()
+     started = time.perf_counter()
+     sequences, answers = load_sequences(args.data_path)
+     if args.max_examples > 0:
+         sequences = sequences[: args.max_examples]
+         answers = answers[: args.max_examples]
+     print(f"Loaded {len(answers)} sequences from {args.data_path}")
+
+     model, tokenizer = load(args.model)
+
+     sep_tokens = encode_no_special(tokenizer, ",")
+     if not sep_tokens:
+         sep_tokens = encode_no_special(tokenizer, "1,")[-1:]
+     prompts = [",".join(str(x) for x in seq) + "," for seq in sequences]
+     prompts = [tokenizer.encode(p) for p in prompts]
+
+     eval_indices = [
+         i
+         for i, prompt in enumerate(prompts)
+         if args.max_context_tokens <= 0 or len(prompt) < args.max_context_tokens
+     ]
+     skipped_long = len(prompts) - len(eval_indices)
+     eval_indices = sorted(eval_indices, key=lambda i: len(prompts[i]))
+     print(
+         f"Evaluating {len(eval_indices)} rows; skipped_long={skipped_long}; "
+         f"sep_tokens={sep_tokens}; eos_token={tokenizer.eos_token_id}; "
+         f"max_new_tokens={args.max_new_tokens}; batch_size={args.batch_size}"
+     )
+
+     stop_tokens = [sep_tokens]
+     if tokenizer.eos_token_id is not None:
+         stop_tokens.append([tokenizer.eos_token_id])
+     logits_processors = None
+     allowed_token_ids = None
+     restrict_integer = args.restrict_digit_comma_eos or args.restrict_integer_comma_eos
+     if restrict_integer:
+         processor, allowed_token_ids = make_integer_comma_eos_processor(tokenizer)
+         logits_processors = [processor]
+         print(f"Restricting logits to integer/comma/EOS token ids: {allowed_token_ids}")
+     args.output.parent.mkdir(parents=True, exist_ok=True)
+     args.summary_output.parent.mkdir(parents=True, exist_ok=True)
+     if args.overwrite and args.output.exists():
+         args.output.unlink()
+     completed = load_completed(args.output)
+     if completed:
+         print(f"Resuming from {args.output}: {len(completed)} rows already done")
+     todo_indices = [idx for idx in eval_indices if idx not in completed]
+
+     with args.output.open("a", encoding="utf-8") as output_file:
+         with tqdm(total=len(todo_indices), desc="Generating") as progress:
+             run_generation_queue(
+                 model=model,
+                 tokenizer=tokenizer,
+                 prompts=[prompts[i] for i in todo_indices],
+                 answers=[answers[i] for i in todo_indices],
+                 row_indices=todo_indices,
+                 stop_tokens=stop_tokens,
+                 max_new_tokens=args.max_new_tokens,
+                 batch_size=args.batch_size,
+                 output_file=output_file,
+                 progress=progress,
+                 logits_processors=logits_processors,
+             )
+         output_file.flush()
+
+     records = load_completed(args.output)
+     eval_set = set(eval_indices)
+     records = {idx: record for idx, record in records.items() if idx in eval_set}
+     correct = sum(1 for record in records.values() if record["correct"])
+     parsed = sum(1 for record in records.values() if record["parsed"])
+     total = len(eval_indices)
+     evaluated = len(records)
+     # Guard against division by zero when no rows were evaluated.
+     accuracy = correct / evaluated if evaluated else 0.0
+     elapsed = time.perf_counter() - started
+
+     print(f"Documents: {len(answers)}")
+     print(f"Evaluated: {evaluated}/{total}")
+     print(f"Skipped long: {skipped_long}")
+     print(f"Parsed predictions: {parsed}/{evaluated}")
+     print(f"Accuracy: {correct}/{evaluated} = {accuracy:.4f}")
+     summary = {
+         "data_path": str(args.data_path),
+         "model": args.model,
+         "output": str(args.output),
+         "documents": len(answers),
+         "evaluated": evaluated,
+         "expected_evaluated": total,
+         "skipped_long": skipped_long,
+         "parsed": parsed,
+         "correct": correct,
+         "accuracy": accuracy,
+         "max_new_tokens": args.max_new_tokens,
+         "max_context_tokens": args.max_context_tokens,
+         "batch_size": args.batch_size,
+         "restrict_digit_comma_eos": args.restrict_digit_comma_eos,
+         "restrict_integer_comma_eos": restrict_integer,
+         "allowed_token_ids": allowed_token_ids,
+         "seconds": elapsed,
+     }
+     args.summary_output.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8")
+     print(f"Wrote {args.output}")
+     print(f"Wrote {args.summary_output}")
+
+ if __name__ == "__main__":
+     main()
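The `--restrict-integer-comma-eos` option above constrains decoding by adding a large negative value to the logit of every token outside an allowed set. The sketch below shows the same additive-mask idea with plain Python lists, so it runs without MLX. It is a simplified stand-in rather than the script's processor: `build_mask`, `apply_mask`, `greedy`, and the five-entry toy vocabulary are assumptions made for illustration.

```python
def build_mask(
    vocab_size: int, allowed_ids: list[int], penalty: float = -1e9
) -> list[float]:
    """Return an additive mask: 0.0 for allowed ids, `penalty` elsewhere."""
    mask = [penalty] * vocab_size
    for token_id in allowed_ids:
        # Ignore ids that fall outside the vocabulary, as the script does.
        if 0 <= token_id < vocab_size:
            mask[token_id] = 0.0
    return mask


def apply_mask(logits: list[float], mask: list[float]) -> list[float]:
    """Add the mask elementwise; disallowed tokens become very unlikely."""
    return [logit + offset for logit, offset in zip(logits, mask)]


def greedy(logits: list[float]) -> int:
    """Pick the index of the largest logit."""
    return max(range(len(logits)), key=logits.__getitem__)


# Toy example: token 3 has the highest raw logit but is not allowed.
logits = [0.1, 2.5, 0.3, 4.0, 1.2]
allowed = [1, 2, 4, 99]  # 99 is out of range and is silently skipped
masked = apply_mask(logits, build_mask(len(logits), allowed))
print(greedy(logits), greedy(masked))
```

An additive mask leaves the relative ordering of the allowed tokens unchanged, so greedy decoding picks the best allowed token. The script also caches the mask per vocabulary size, which this sketch omits.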
radar_chart.png ADDED

Git LFS Details

  • SHA256: e86a7da31e209952caa5f8de431eb1813c801f0f8fcffaf462f8b828a84fc8ff
  • Pointer size: 131 Bytes
  • Size of remote file: 411 kB