N8Programs commited on
Commit
a46649b
·
verified ·
1 Parent(s): 5db721a

Bundle evaluation datasets with model card scripts

Browse files
README.md CHANGED
@@ -142,11 +142,13 @@ text before the first comma as the predicted integer.
142
  ## Reproducibility
143
 
144
  This repository contains the local evaluation scripts and artifacts used for the
145
- results above, including:
146
 
147
  - `oeis_eval_mlx_neo.py` for OEIS-Eval-Neo with MLX batch generation.
148
  - `arithmetic_eval.py` for arithmetic/quadratic/cubic/quartic continuation.
149
  - `eval_m1_competition_mape_mlx.py` for M1 Competition 111 MAPE.
 
 
150
  - `eval_results.txt` for the compact result table.
151
 
152
 
 
142
  ## Reproducibility
143
 
144
  This repository contains the local evaluation scripts and artifacts used for the
145
+ results above, including the small evaluation datasets needed to rerun them:
146
 
147
  - `oeis_eval_mlx_neo.py` for OEIS-Eval-Neo with MLX batch generation.
148
  - `arithmetic_eval.py` for arithmetic/quadratic/cubic/quartic continuation.
149
  - `eval_m1_competition_mape_mlx.py` for M1 Competition 111 MAPE.
150
+ - `oeis_val_neo.jsonl` for OEIS-Eval-Neo.
151
+ - `m1_competition_111.jsonl` for M1 Competition 111.
152
  - `eval_results.txt` for the compact result table.
153
 
154
 
arithmetic_eval.py CHANGED
@@ -19,6 +19,7 @@ import argparse
19
  import csv
20
  import random
21
  from dataclasses import dataclass
 
22
  from typing import Dict, List, Tuple
23
 
24
  from mlx_lm import load
@@ -26,8 +27,19 @@ from mlx_lm.generate import BatchGenerator
26
  from mlx_lm.tuner.utils import print_trainable_parameters
27
  from tqdm import tqdm
28
 
29
- # Step 335000
30
- MODEL_NAME = "NextTerm-47M"
 
 
 
 
 
 
 
 
 
 
 
31
  MAX_NEW_TOKENS = 20
32
  DEFAULT_MAX_TERMS = 25
33
  SAMPLES_PER_K = 200
 
19
  import csv
20
  import random
21
  from dataclasses import dataclass
22
+ from pathlib import Path
23
  from typing import Dict, List, Tuple
24
 
25
  from mlx_lm import load
 
27
  from mlx_lm.tuner.utils import print_trainable_parameters
28
  from tqdm import tqdm
29
 
30
+ SCRIPT_DIR = Path(__file__).resolve().parent
31
+
32
+
33
+ def default_model_path() -> str:
34
+ if (SCRIPT_DIR / "model.safetensors").exists():
35
+ return str(SCRIPT_DIR)
36
+ local_model = SCRIPT_DIR / "NextTerm-47M"
37
+ if local_model.exists():
38
+ return str(local_model)
39
+ return "N8Programs/NextTerm-47M"
40
+
41
+
42
+ MODEL_NAME = default_model_path()
43
  MAX_NEW_TOKENS = 20
44
  DEFAULT_MAX_TERMS = 25
45
  SAMPLES_PER_K = 200
eval_m1_competition_mape_mlx.py CHANGED
@@ -28,8 +28,20 @@ from eval_m1_monthly_mape_mlx import (
28
  )
29
 
30
 
31
- DATA_PATH = Path("m1_competition_111.jsonl")
32
- MODEL_PATH = Path("NextTerm-440M")
 
 
 
 
 
 
 
 
 
 
 
 
33
  OUTPUT_PATH = Path("m1_eval_results/m1_competition111_nextterm440m_greedy_per_series.jsonl")
34
  SUMMARY_PATH = Path("m1_eval_results/m1_competition111_nextterm440m_greedy_summary.json")
35
 
 
28
  )
29
 
30
 
31
+ SCRIPT_DIR = Path(__file__).resolve().parent
32
+
33
+
34
+ def default_model_path() -> Path:
35
+ if (SCRIPT_DIR / "model.safetensors").exists():
36
+ return SCRIPT_DIR
37
+ local_model = SCRIPT_DIR / "NextTerm-440M"
38
+ if local_model.exists():
39
+ return local_model
40
+ return Path("N8Programs/NextTerm-440M")
41
+
42
+
43
+ DATA_PATH = SCRIPT_DIR / "m1_competition_111.jsonl"
44
+ MODEL_PATH = default_model_path()
45
  OUTPUT_PATH = Path("m1_eval_results/m1_competition111_nextterm440m_greedy_per_series.jsonl")
46
  SUMMARY_PATH = Path("m1_eval_results/m1_competition111_nextterm440m_greedy_summary.json")
47
 
eval_m1_monthly_mape_mlx.py ADDED
@@ -0,0 +1,583 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Evaluate NextTerm-style MLX models on the M1 monthly forecasting dataset."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import gc
8
+ import json
9
+ import math
10
+ import re
11
+ import time
12
+ from dataclasses import dataclass
13
+ from decimal import Decimal, InvalidOperation, localcontext
14
+ from pathlib import Path
15
+ from statistics import mean, median
16
+
17
+ import mlx.core as mx
18
+ from mlx_lm import load
19
+ from mlx_lm.generate import BatchGenerator
20
+ from tqdm import tqdm
21
+
22
+
23
+ SCRIPT_DIR = Path(__file__).resolve().parent
24
+
25
+
26
+ def default_model_path() -> Path:
27
+ if (SCRIPT_DIR / "model.safetensors").exists():
28
+ return SCRIPT_DIR
29
+ local_model = SCRIPT_DIR / "NextTerm-47M"
30
+ if local_model.exists():
31
+ return local_model
32
+ return Path("N8Programs/NextTerm-47M")
33
+
34
+
35
+ DATA_PATH = SCRIPT_DIR / "m1_monthly_dataset.txt"
36
+ MODEL_PATH = default_model_path()
37
+ OUTPUT_PATH = Path("m1_eval_results/m1_monthly_nextterm47m_per_series.jsonl")
38
+ SUMMARY_PATH = Path("m1_eval_results/m1_monthly_nextterm47m_summary.json")
39
+
40
+
41
+ @dataclass
42
+ class SeriesRecord:
43
+ row_index: int
44
+ series_name: str
45
+ start_timestamp: str
46
+ raw_values: list[str]
47
+ values: list[Decimal]
48
+ context_values: list[Decimal]
49
+ target_values: list[Decimal]
50
+ scale: int
51
+ scaled_context: list[int]
52
+
53
+
54
+ class SuppressTokenLogits:
55
+ """Set selected token logits to a large negative value."""
56
+
57
+ def __init__(self, token_ids: list[int]):
58
+ self.token_ids = sorted({int(t) for t in token_ids if t is not None and int(t) >= 0})
59
+ self._bias_by_width: dict[int, mx.array] = {}
60
+
61
+ def __call__(self, tokens: mx.array, logits: mx.array) -> mx.array:
62
+ width = int(logits.shape[-1])
63
+ bias = self._bias_by_width.get(width)
64
+ if bias is None:
65
+ values = [0.0] * width
66
+ for token_id in self.token_ids:
67
+ if token_id < width:
68
+ values[token_id] = -1.0e9
69
+ bias = mx.array(values, dtype=logits.dtype)
70
+ self._bias_by_width[width] = bias
71
+ return logits + bias
72
+
73
+
74
+ def parse_decimal(raw: str) -> Decimal:
75
+ raw = raw.strip()
76
+ try:
77
+ value = Decimal(raw)
78
+ except InvalidOperation as exc:
79
+ raise ValueError(f"Could not parse decimal value {raw!r}") from exc
80
+ if not value.is_finite():
81
+ raise ValueError(f"Non-finite decimal value {raw!r}")
82
+ return value
83
+
84
+
85
+ def context_scale(raw_context_values: list[str]) -> int:
86
+ """Choose an integer scale using only visible decimal precision in context."""
87
+ max_places = 0
88
+ for raw in raw_context_values:
89
+ value = parse_decimal(raw)
90
+ exponent = value.as_tuple().exponent
91
+ if exponent < 0:
92
+ max_places = max(max_places, -exponent)
93
+ return 10**max_places
94
+
95
+
96
+ def scale_decimal_to_int(value: Decimal, scale: int) -> int:
97
+ scaled = value * Decimal(scale)
98
+ with localcontext() as ctx:
99
+ ctx.prec = max(50, len(scaled.as_tuple().digits) + 10)
100
+ rounded = scaled.to_integral_value()
101
+ if rounded != scaled:
102
+ # This can only happen if the held-out side has more precision than the
103
+ # context-derived scale. It is fine for scoring, but not for prompting.
104
+ raise ValueError(f"Context scale {scale} does not make {value} integral")
105
+ return int(rounded)
106
+
107
+
108
+ def load_m1_monthly(path: Path, horizon: int | None = None) -> tuple[list[SeriesRecord], int]:
109
+ records: list[SeriesRecord] = []
110
+ parsed_horizon: int | None = horizon
111
+ in_data = False
112
+ with path.open("r", encoding="latin-1", newline=None) as f:
113
+ for raw_line in f:
114
+ line = raw_line.strip()
115
+ if not line:
116
+ continue
117
+ if line.startswith("@horizon") and parsed_horizon is None:
118
+ parts = line.split()
119
+ if len(parts) >= 2:
120
+ parsed_horizon = int(parts[1])
121
+ continue
122
+ if line == "@data":
123
+ in_data = True
124
+ continue
125
+ if not in_data or line.startswith("#") or line.startswith("@"):
126
+ continue
127
+
128
+ try:
129
+ series_name, start_timestamp, values_blob = line.split(":", 2)
130
+ except ValueError as exc:
131
+ raise ValueError(f"Malformed M1 data line: {line[:120]!r}") from exc
132
+ raw_values = [part.strip() for part in values_blob.split(",") if part.strip()]
133
+ values = [parse_decimal(part) for part in raw_values]
134
+ if parsed_horizon is None:
135
+ raise ValueError("No horizon provided and no @horizon metadata found")
136
+ if len(values) <= parsed_horizon:
137
+ continue
138
+
139
+ raw_context = raw_values[:-parsed_horizon]
140
+ scale = context_scale(raw_context)
141
+ context_values = values[:-parsed_horizon]
142
+ target_values = values[-parsed_horizon:]
143
+ scaled_context = [scale_decimal_to_int(v, scale) for v in context_values]
144
+ records.append(
145
+ SeriesRecord(
146
+ row_index=len(records),
147
+ series_name=series_name,
148
+ start_timestamp=start_timestamp,
149
+ raw_values=raw_values,
150
+ values=values,
151
+ context_values=context_values,
152
+ target_values=target_values,
153
+ scale=scale,
154
+ scaled_context=scaled_context,
155
+ )
156
+ )
157
+ if parsed_horizon is None:
158
+ raise ValueError("No horizon provided and no @horizon metadata found")
159
+ return records, parsed_horizon
160
+
161
+
162
+ def parse_generated_terms(text: str, limit: int) -> tuple[list[int], bool]:
163
+ terms: list[int] = []
164
+ current: list[str] = []
165
+ malformed = False
166
+
167
+ def flush_current() -> None:
168
+ nonlocal malformed
169
+ if not current:
170
+ return
171
+ token = "".join(current)
172
+ current.clear()
173
+ if token == "-":
174
+ malformed = True
175
+ return
176
+ try:
177
+ terms.append(int(token))
178
+ except ValueError:
179
+ malformed = True
180
+
181
+ for ch in text:
182
+ if len(terms) >= limit:
183
+ break
184
+ if ch.isdigit() or (ch == "-" and not current):
185
+ current.append(ch)
186
+ elif ch == ",":
187
+ flush_current()
188
+ elif ch.isspace():
189
+ continue
190
+ else:
191
+ malformed = True
192
+ flush_current()
193
+ break
194
+ if len(terms) < limit:
195
+ flush_current()
196
+ return terms[:limit], malformed
197
+
198
+
199
+ def as_float(value: Decimal) -> float:
200
+ return float(value)
201
+
202
+
203
+ def ape(actual: Decimal, prediction: Decimal) -> float | None:
204
+ if actual == 0:
205
+ return None
206
+ return float(abs(actual - prediction) / abs(actual) * Decimal(100))
207
+
208
+
209
+ def mape_for_predictions(
210
+ actuals: list[Decimal],
211
+ predictions: list[Decimal | None],
212
+ *,
213
+ missing_policy: str,
214
+ fallback: Decimal,
215
+ ) -> tuple[float | None, list[float | None], int]:
216
+ apes: list[float | None] = []
217
+ missing = 0
218
+ for actual, prediction in zip(actuals, predictions):
219
+ pred = prediction
220
+ if pred is None:
221
+ missing += 1
222
+ if missing_policy == "skip":
223
+ apes.append(None)
224
+ continue
225
+ if missing_policy == "zero":
226
+ pred = Decimal(0)
227
+ elif missing_policy == "last_context":
228
+ pred = fallback
229
+ else:
230
+ raise ValueError(f"Unknown missing policy: {missing_policy}")
231
+ apes.append(ape(actual, pred))
232
+ valid = [x for x in apes if x is not None and math.isfinite(x)]
233
+ return (mean(valid) if valid else None), apes, missing
234
+
235
+
236
+ def seasonal_naive_predictions(context: list[Decimal], horizon: int, season: int = 12) -> list[Decimal]:
237
+ if not context:
238
+ return [Decimal(0)] * horizon
239
+ if len(context) < season:
240
+ return [context[-1]] * horizon
241
+ last_season = context[-season:]
242
+ return [last_season[i % season] for i in range(horizon)]
243
+
244
+
245
+ def last_value_predictions(context: list[Decimal], horizon: int) -> list[Decimal]:
246
+ fallback = context[-1] if context else Decimal(0)
247
+ return [fallback] * horizon
248
+
249
+
250
+ def load_completed(path: Path) -> dict[int, dict]:
251
+ completed: dict[int, dict] = {}
252
+ if not path.exists():
253
+ return completed
254
+ with path.open("r", encoding="utf-8") as f:
255
+ for line in f:
256
+ if not line.strip():
257
+ continue
258
+ record = json.loads(line)
259
+ completed[int(record["row_index"])] = record
260
+ return completed
261
+
262
+
263
+ def aggregate_summary(records: list[dict], horizon: int) -> dict:
264
+ def collect(key: str) -> list[float]:
265
+ return [
266
+ float(r[key])
267
+ for r in records
268
+ if r.get(key) is not None and math.isfinite(float(r[key]))
269
+ ]
270
+
271
+ model_series = collect("mape")
272
+ seasonal_series = collect("seasonal_naive_mape")
273
+ last_series = collect("last_value_mape")
274
+ per_horizon: list[dict] = []
275
+ for h in range(horizon):
276
+ vals = []
277
+ seasonal_vals = []
278
+ last_vals = []
279
+ for r in records:
280
+ for source, target in [
281
+ ("apes", vals),
282
+ ("seasonal_naive_apes", seasonal_vals),
283
+ ("last_value_apes", last_vals),
284
+ ]:
285
+ xs = r.get(source) or []
286
+ if h < len(xs) and xs[h] is not None and math.isfinite(float(xs[h])):
287
+ target.append(float(xs[h]))
288
+ per_horizon.append(
289
+ {
290
+ "horizon": h + 1,
291
+ "mape": mean(vals) if vals else None,
292
+ "seasonal_naive_mape": mean(seasonal_vals) if seasonal_vals else None,
293
+ "last_value_mape": mean(last_vals) if last_vals else None,
294
+ "n": len(vals),
295
+ }
296
+ )
297
+
298
+ all_apes = [
299
+ float(x)
300
+ for r in records
301
+ for x in (r.get("apes") or [])
302
+ if x is not None and math.isfinite(float(x))
303
+ ]
304
+ all_seasonal_apes = [
305
+ float(x)
306
+ for r in records
307
+ for x in (r.get("seasonal_naive_apes") or [])
308
+ if x is not None and math.isfinite(float(x))
309
+ ]
310
+ all_last_apes = [
311
+ float(x)
312
+ for r in records
313
+ for x in (r.get("last_value_apes") or [])
314
+ if x is not None and math.isfinite(float(x))
315
+ ]
316
+ return {
317
+ "series_count": len(records),
318
+ "model_macro_mape": mean(model_series) if model_series else None,
319
+ "model_median_series_mape": median(model_series) if model_series else None,
320
+ "model_point_mape": mean(all_apes) if all_apes else None,
321
+ "seasonal_naive_macro_mape": mean(seasonal_series) if seasonal_series else None,
322
+ "seasonal_naive_point_mape": mean(all_seasonal_apes) if all_seasonal_apes else None,
323
+ "last_value_macro_mape": mean(last_series) if last_series else None,
324
+ "last_value_point_mape": mean(all_last_apes) if all_last_apes else None,
325
+ "parsed_full_series": sum(1 for r in records if r.get("parsed_terms", 0) >= horizon),
326
+ "total_missing_terms": sum(int(r.get("missing_terms", 0)) for r in records),
327
+ "malformed_series": sum(1 for r in records if r.get("malformed", False)),
328
+ "per_horizon": per_horizon,
329
+ }
330
+
331
+
332
+ def parse_args() -> argparse.Namespace:
333
+ parser = argparse.ArgumentParser()
334
+ parser.add_argument("--data-path", type=Path, default=DATA_PATH)
335
+ parser.add_argument("--model", type=Path, default=MODEL_PATH)
336
+ parser.add_argument("--output", type=Path, default=OUTPUT_PATH)
337
+ parser.add_argument("--summary-output", type=Path, default=SUMMARY_PATH)
338
+ parser.add_argument("--horizon", type=int, default=None)
339
+ parser.add_argument("--batch-size", type=int, default=64)
340
+ parser.add_argument("--max-new-tokens", type=int, default=384)
341
+ parser.add_argument("--max-context-tokens", type=int, default=2048)
342
+ parser.add_argument("--max-series", type=int, default=0)
343
+ parser.add_argument(
344
+ "--suppress-eos-until-horizon",
345
+ action="store_true",
346
+ help="Suppress EOS/pad logits and stop only after the requested separator count or length cap.",
347
+ )
348
+ parser.add_argument(
349
+ "--rerun-incomplete",
350
+ action="store_true",
351
+ help="When resuming, rerun rows whose previous record parsed fewer than horizon terms.",
352
+ )
353
+ parser.add_argument(
354
+ "--missing-policy",
355
+ choices=["zero", "last_context", "skip"],
356
+ default="zero",
357
+ help="How to score missing/unparseable forecast terms.",
358
+ )
359
+ parser.add_argument("--overwrite", action="store_true")
360
+ return parser.parse_args()
361
+
362
+
363
+ def main() -> None:
364
+ args = parse_args()
365
+ started = time.perf_counter()
366
+ series, horizon = load_m1_monthly(args.data_path, args.horizon)
367
+ if args.max_series > 0:
368
+ series = series[: args.max_series]
369
+ print(f"Loaded {len(series)} M1 monthly series from {args.data_path}; horizon={horizon}")
370
+
371
+ model, tokenizer = load(str(args.model))
372
+ print(f"Loaded model: {args.model}")
373
+
374
+ prompts_text = [",".join(str(x) for x in s.scaled_context) + "," for s in series]
375
+ prompts = [tokenizer.encode(text) for text in prompts_text]
376
+ sep_token = tokenizer.encode("1,")[-1]
377
+ eval_indices = [
378
+ i
379
+ for i, prompt in enumerate(prompts)
380
+ if args.max_context_tokens <= 0 or len(prompt) < args.max_context_tokens
381
+ ]
382
+ skipped_long = len(prompts) - len(eval_indices)
383
+ eval_indices = sorted(eval_indices, key=lambda i: len(prompts[i]))
384
+ print(
385
+ f"Evaluating {len(eval_indices)} series; skipped_long={skipped_long}; "
386
+ f"batch_size={args.batch_size}; max_new_tokens={args.max_new_tokens}; sep_token={sep_token}"
387
+ )
388
+
389
+ args.output.parent.mkdir(parents=True, exist_ok=True)
390
+ args.summary_output.parent.mkdir(parents=True, exist_ok=True)
391
+ if args.overwrite and args.output.exists():
392
+ args.output.unlink()
393
+
394
+ completed = load_completed(args.output)
395
+ if args.rerun_incomplete:
396
+ incomplete = [
397
+ idx
398
+ for idx, record in completed.items()
399
+ if int(record.get("parsed_terms", 0)) < horizon
400
+ ]
401
+ for idx in incomplete:
402
+ completed.pop(idx, None)
403
+ if incomplete:
404
+ print(
405
+ "Rerunning incomplete rows: "
406
+ + ", ".join(str(idx) for idx in sorted(incomplete))
407
+ )
408
+ if completed:
409
+ print(f"Resuming from {args.output}: {len(completed)} series already done")
410
+ todo_indices = [i for i in eval_indices if i not in completed]
411
+
412
+ stop_tokens = [] if args.suppress_eos_until_horizon else [[tokenizer.eos_token_id]]
413
+ suppress_token_ids: list[int] = []
414
+ if getattr(tokenizer, "eos_token_id", None) is not None:
415
+ suppress_token_ids.append(tokenizer.eos_token_id)
416
+ if getattr(tokenizer, "pad_token_id", None) is not None:
417
+ if not args.suppress_eos_until_horizon:
418
+ stop_tokens.append([tokenizer.pad_token_id])
419
+ suppress_token_ids.append(tokenizer.pad_token_id)
420
+ suppress_processor = (
421
+ SuppressTokenLogits(suppress_token_ids) if args.suppress_eos_until_horizon else None
422
+ )
423
+ gen = BatchGenerator(
424
+ model,
425
+ stop_tokens=stop_tokens or None,
426
+ completion_batch_size=args.batch_size,
427
+ prefill_batch_size=args.batch_size,
428
+ )
429
+ uid_to_idx: dict[int, int] = {}
430
+ generated_tokens: dict[int, list[int]] = {}
431
+ separator_counts: dict[int, int] = {}
432
+ if todo_indices:
433
+ logits_processors = (
434
+ [[suppress_processor] for _ in todo_indices]
435
+ if suppress_processor is not None
436
+ else None
437
+ )
438
+ uids = gen.insert(
439
+ [prompts[i] for i in todo_indices],
440
+ [args.max_new_tokens] * len(todo_indices),
441
+ logits_processors=logits_processors,
442
+ )
443
+ uid_to_idx = {uid: idx for uid, idx in zip(uids, todo_indices)}
444
+ generated_tokens = {uid: [] for uid in uids}
445
+ separator_counts = {uid: 0 for uid in uids}
446
+ finished: set[int] = set()
447
+
448
+ try:
449
+ with args.output.open("a", encoding="utf-8") as out:
450
+ with tqdm(total=len(todo_indices), desc="Generating") as progress:
451
+ def finalize_uid(uid: int, finish_reason: str) -> None:
452
+ finished.add(uid)
453
+ idx = uid_to_idx[uid]
454
+ s = series[idx]
455
+ generated_text = tokenizer.decode(generated_tokens[uid])
456
+ scaled_terms, malformed = parse_generated_terms(generated_text, horizon)
457
+ predictions: list[Decimal | None] = [
458
+ Decimal(term) / Decimal(s.scale) for term in scaled_terms
459
+ ]
460
+ predictions.extend([None] * (horizon - len(predictions)))
461
+ fallback = s.context_values[-1] if s.context_values else Decimal(0)
462
+ mape, apes, missing = mape_for_predictions(
463
+ s.target_values,
464
+ predictions,
465
+ missing_policy=args.missing_policy,
466
+ fallback=fallback,
467
+ )
468
+ seasonal_preds = seasonal_naive_predictions(s.context_values, horizon)
469
+ seasonal_mape, seasonal_apes, _ = mape_for_predictions(
470
+ s.target_values,
471
+ seasonal_preds,
472
+ missing_policy="skip",
473
+ fallback=fallback,
474
+ )
475
+ last_preds = last_value_predictions(s.context_values, horizon)
476
+ last_mape, last_apes, _ = mape_for_predictions(
477
+ s.target_values,
478
+ last_preds,
479
+ missing_policy="skip",
480
+ fallback=fallback,
481
+ )
482
+ record = {
483
+ "row_index": idx,
484
+ "series_name": s.series_name,
485
+ "start_timestamp": s.start_timestamp,
486
+ "scale": s.scale,
487
+ "context_length": len(s.context_values),
488
+ "prompt_tokens": len(prompts[idx]),
489
+ "target": [str(x) for x in s.target_values],
490
+ "scaled_prediction_terms": scaled_terms,
491
+ "prediction": [str(x) if x is not None else None for x in predictions],
492
+ "parsed_terms": len(scaled_terms),
493
+ "missing_terms": missing,
494
+ "malformed": malformed,
495
+ "mape": mape,
496
+ "apes": apes,
497
+ "seasonal_naive_mape": seasonal_mape,
498
+ "seasonal_naive_apes": seasonal_apes,
499
+ "last_value_mape": last_mape,
500
+ "last_value_apes": last_apes,
501
+ "generated_text": generated_text,
502
+ "finish_reason": finish_reason,
503
+ }
504
+ out.write(json.dumps(record) + "\n")
505
+ out.flush()
506
+ progress.update(1)
507
+
508
+ while todo_indices:
509
+ responses = gen.next()
510
+ if not isinstance(responses, tuple) or len(responses) != 2:
511
+ raise RuntimeError(
512
+ "Unexpected mlx_lm BatchGenerator.next() API. "
513
+ "Update your mlx_lm version."
514
+ )
515
+ prompt_responses, generation_responses = responses
516
+ if not prompt_responses and not generation_responses:
517
+ break
518
+ remove_uids: list[int] = []
519
+ for response in generation_responses:
520
+ uid = response.uid
521
+ if uid in finished:
522
+ continue
523
+ if response.finish_reason != "stop":
524
+ token = int(response.token)
525
+ generated_tokens[uid].append(token)
526
+ if token == sep_token:
527
+ separator_counts[uid] += 1
528
+ if separator_counts[uid] >= horizon:
529
+ finalize_uid(uid, "separator_count")
530
+ remove_uids.append(uid)
531
+ continue
532
+ if response.finish_reason is not None:
533
+ finalize_uid(uid, str(response.finish_reason))
534
+ if remove_uids:
535
+ gen.remove(remove_uids)
536
+ if len(finished) != len(todo_indices):
537
+ raise RuntimeError(f"Finished {len(finished)}/{len(todo_indices)} series")
538
+ finally:
539
+ gen.close()
540
+ mx.clear_cache()
541
+ gc.collect()
542
+
543
+ all_records = [
544
+ r
545
+ for idx, r in sorted(load_completed(args.output).items())
546
+ if idx in set(eval_indices)
547
+ ]
548
+ aggregate = aggregate_summary(all_records, horizon)
549
+ elapsed = time.perf_counter() - started
550
+ summary = {
551
+ "data_path": str(args.data_path),
552
+ "model": str(args.model),
553
+ "output": str(args.output),
554
+ "horizon": horizon,
555
+ "series_loaded": len(series),
556
+ "series_evaluated": len(all_records),
557
+ "skipped_long": skipped_long,
558
+ "max_context_tokens": args.max_context_tokens,
559
+ "max_new_tokens": args.max_new_tokens,
560
+ "batch_size": args.batch_size,
561
+ "missing_policy": args.missing_policy,
562
+ "suppress_eos_until_horizon": args.suppress_eos_until_horizon,
563
+ "seconds": elapsed,
564
+ **aggregate,
565
+ }
566
+ args.summary_output.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8")
567
+ print(json.dumps({k: summary[k] for k in [
568
+ "series_evaluated",
569
+ "model_macro_mape",
570
+ "model_point_mape",
571
+ "seasonal_naive_macro_mape",
572
+ "last_value_macro_mape",
573
+ "parsed_full_series",
574
+ "total_missing_terms",
575
+ "malformed_series",
576
+ "seconds",
577
+ ]}, indent=2))
578
+ print(f"Wrote {args.output}")
579
+ print(f"Wrote {args.summary_output}")
580
+
581
+
582
+ if __name__ == "__main__":
583
+ main()
m1_competition_111.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
m1_monthly_dataset.txt ADDED
The diff for this file is too large to render. See raw diff
 
oeis_eval_mlx_neo.py CHANGED
@@ -12,8 +12,20 @@ from mlx_lm import load
12
  from mlx_lm.generate import BatchGenerator
13
  from tqdm import tqdm
14
 
15
- DATA_PATH = Path("/Users/natebreslow/Documents/khashabiLab/bigOEIS/oeis_val_neo.jsonl")
16
- MODEL_NAME = "/Users/natebreslow/Documents/khashabiLab/bigOEIS/NextTerm-440M"
 
 
 
 
 
 
 
 
 
 
 
 
17
  MAX_NEW_TOKENS = 196
18
  MAX_CONTEXT_TOKENS = 4096
19
  BATCH_SIZE = 64
 
12
  from mlx_lm.generate import BatchGenerator
13
  from tqdm import tqdm
14
 
15
+ SCRIPT_DIR = Path(__file__).resolve().parent
16
+
17
+
18
+ def default_model_path() -> str:
19
+ if (SCRIPT_DIR / "model.safetensors").exists():
20
+ return str(SCRIPT_DIR)
21
+ local_model = SCRIPT_DIR / "NextTerm-440M"
22
+ if local_model.exists():
23
+ return str(local_model)
24
+ return "N8Programs/NextTerm-440M"
25
+
26
+
27
+ DATA_PATH = SCRIPT_DIR / "oeis_val_neo.jsonl"
28
+ MODEL_NAME = default_model_path()
29
  MAX_NEW_TOKENS = 196
30
  MAX_CONTEXT_TOKENS = 4096
31
  BATCH_SIZE = 64
oeis_val_neo.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
oeis_val_neo.meta.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "oeis_val_neo",
3
+ "description": "OEIS Eval Neo: oeis_val_decontam with exact packed-sequence overlaps against the synthetic training shard removed.",
4
+ "source_jsonl": "/Users/natebreslow/Documents/khashabiLab/bigOEIS/oeis_val_decontam.jsonl",
5
+ "overlap_source_tsv": "/Users/natebreslow/Documents/khashabiLab/bigOEIS/packed_data/oeis_val_vs_synth_aug0_inv_len_13245370099_seed0.matches_with_ids.tsv",
6
+ "excluded_rows": 73,
7
+ "kept_rows": 19035,
8
+ "indexing": "needle_seq_id is 0-based and jsonl_line_1based = needle_seq_id + 1"
9
+ }
oeis_val_neo_excluded_exact_packed_matches.tsv ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ needle_seq_id jsonl_line_1based oeis_id match_pairs first_haystack_seq_id packed_content_tokens terms first_12_terms last_3_terms
2
+ 545 546 A167176 2 15370293 209 105 [0, 1, 8, 0, 1, 8, 0, 1, 8, 0, 1, 8] [0, 1, 8]
3
+ 812 813 A368056 11 5322373 13 6 [1, 2, 4, 8, 16, 17] [8, 16, 17]
4
+ 892 893 A316981 7 4912247 17 7 [1, 1, 2, 6, 15, 40, 121] [15, 40, 121]
5
+ 938 939 A000811 1 55273515 153 7 [16, 160, 13056, 183305216, 153746461690757120, 472614400151965797795362900461748224, 22974620880419880070513913016918297106276186594558904213019777370224066560] [153746461690757120, 472614400151965797795362900461748224, 22974620880419880070513913016918297106276186594558904213019777370224066560]
6
+ 984 985 A166486 1 14419403 201 101 [0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1] [1, 1, 0]
7
+ 1418 1419 A044795 1 6999629 182 39 [82, 182, 282, 382, 482, 582, 682, 782, 829, 882, 982, 1082] [3382, 3482, 3582]
8
+ 1561 1562 A273720 1 48019539 224 28 [1, 3, 8, 21, 57, 162, 479, 1458, 4528, 14259, 45349, 145289] [2403295565913, 7966021263923, 26425616887971]
9
+ 1870 1871 A044758 1 41462953 182 39 [45, 145, 245, 345, 445, 459, 545, 645, 745, 845, 945, 1045] [3345, 3445, 3459]
10
+ 2152 2153 A249541 1 23907330 55 9 [3, 4, 5, 17, 257, 65537, 83623937, 4294967297, 6992962672132097] [83623937, 4294967297, 6992962672132097]
11
+ 2375 2376 A279069 1 57156 142 56 [1, 2, 2, 4, 16, 4, 2, 18, 4, 4, 6, 12] [6, 22, 4]
12
+ 2429 2430 A010699 1 67423087 161 81 [2, 9, 2, 9, 2, 9, 2, 9, 2, 9, 2, 9] [2, 9, 2]
13
+ 2537 2538 A048058 2 13019375 214 51 [11, 13, 17, 23, 31, 41, 53, 67, 83, 101, 121, 143] [2363, 2461, 2561]
14
+ 3059 3060 A051637 28 1481960 10 5 [1, 2, 3, 7, 10] [3, 7, 10]
15
+ 3298 3299 A179101 5 17275932 15 7 [2, 3, 5, 6, 8, 11, 14] [8, 11, 14]
16
+ 3312 3313 A084528 1 23488569 24 8 [1, 2, 5, 17, 59, 201, 703, 2405] [201, 703, 2405]
17
+ 3440 3441 A122059 3 4681413 17 9 [1, 0, 0, 1, 1, 2, 3, 0, 4] [3, 0, 4]
18
+ 3507 3508 A047578 1 9038770 194 62 [2, 5, 6, 7, 10, 13, 14, 15, 18, 21, 22, 23] [119, 122, 125]
19
+ 4130 4131 A228407 1 50962227 192 59 [0, 11, 1, 10, 100, 12, 2, 20, 101, 22, 3, 13] [134, 143, 314]
20
+ 4236 4237 A063160 4 13631973 193 48 [10, 33, 57, 81, 105, 129, 153, 177, 201, 225, 249, 273] [1089, 1113, 1137]
21
+ 4572 4573 A009836 2 58624055 170 14 [0, 2, 8, 368, 16512, 1583104, 199552000, 36445579264, 8620299812864, 2621816292114432, 987354046567284736, 452732308336619290624] [452732308336619290624, 247917555997339251900416, 159904531672039230122491904]
22
+ 4682 4683 A145039 1 5184810 98 18 [3, 7, 19, 31, 107, 127, 607, 1279, 2203, 4423, 86243, 110503] [20996011, 24036583, 25964951]
23
+ 5618 5619 A263858 1 45502052 36 17 [1, 1, 1, 1, 1, 3, 1, 1, 7, 6, 2, 1] [18, 4, 2]
24
+ 6137 6138 A088689 2 17024337 209 105 [0, 1, 1, 0, 2, 2, 0, 1, 1, 0, 2, 2] [0, 1, 1]
25
+ 7257 7258 A044615 1 29890298 186 41 [47, 111, 175, 239, 303, 367, 383, 431, 495, 559, 623, 687] [2223, 2287, 2351]
26
+ 7653 7654 A062744 1 41839518 135 15 [1, 1, 10, 145, 2470, 46060, 910252, 18730855, 397089550, 8612835715, 190223180840, 4263421511271] [96723482198980, 2216905597676000, 51256802757808320]
27
+ 8155 8156 A270840 1 58982447 204 26 [12, 36, 138, 270, 546, 4800, 7560, 12840, 14700, 358200, 678480, 16139970] [5088964800, 6974736756, 9214178820]
28
+ 8251 8252 A155645 1 1385507 193 20 [1, 12, 84, 558, 3696, 24582, 164304, 1103478, 7444416, 50431302, 342941424, 2340123798] [249557173431942, 1729973554578864, 12008254925383638]
29
+ 8279 8280 A290453 4 29575809 181 91 [1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 2] [0, 1, 0]
30
+ 8477 8478 A044402 3 19293444 182 39 [70, 170, 270, 370, 470, 570, 670, 700, 770, 870, 970, 1070] [3370, 3470, 3570]
31
+ 9025 9026 A068987 1 62874440 85 11 [2, 149, 1925, 13808, 49703, 2458886, 9470345, 186557267, 523551503, 191278379840, 4368196101672] [523551503, 191278379840, 4368196101672]
32
+ 9707 9708 A212167 3 33046 191 66 [1, 2, 3, 5, 6, 7, 10, 11, 12, 13, 14, 15] [79, 82, 83]
33
+ 9721 9722 A293077 3 14353436 199 35 [2, 4, 6, 10, 16, 26, 44, 74, 126, 214, 364, 620] [46697496, 79717612, 136086476]
34
+ 9894 9895 A009705 5 24197587 164 14 [0, 2, 4, 302, 8104, 947642, 86855404, 16203909542, 3130092938704, 896924477276402, 290713861720990804, 121467176505314129822] [121467176505314129822, 58492863120535523766904, 33925794100542844193202602]
35
+ 10305 10306 A135839 2 3672765 181 91 [1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0] [1, 0, 1]
36
+ 10600 10601 A101587 2 594093 59 14 [2, 4, 5, 20, 308, 815, 857, 1114, 1418, 1688, 12008, 18692] [18692, 28097, 90964]
37
+ 10681 10682 A024036 3 21353804 219 25 [0, 3, 15, 63, 255, 1023, 4095, 16383, 65535, 262143, 1048575, 4194303] [17592186044415, 70368744177663, 281474976710655]
38
+ 10706 10707 A100285 1 18949167 179 90 [1, 1, 5, 5, 1, 1, 5, 5, 1, 1, 5, 5] [5, 1, 1]
39
+ 11007 11008 A055843 1 3608436 231 31 [1, 13, 85, 385, 1375, 4147, 11011, 26455, 58630, 121550, 238238, 445094] [406833460, 536222500, 700950052]
40
+ 11079 11080 A017629 2 52873588 202 53 [9, 21, 33, 45, 57, 69, 81, 93, 105, 117, 129, 141] [609, 621, 633]
41
+ 11234 11235 A007520 1 28251297 202 51 [3, 11, 19, 43, 59, 67, 83, 107, 131, 139, 163, 179] [1163, 1171, 1187]
42
+ 11307 11308 A299729 1 7973728 201 54 [12, 24, 30, 36, 40, 48, 60, 63, 70, 72, 80, 84] [320, 324, 325]
43
+ 11450 11451 A010126 1 52767868 161 81 [4, 1, 2, 4, 2, 1, 8, 1, 2, 4, 2, 1] [8, 1, 2]
44
+ 12022 12023 A380991 9 1893358 9 4 [1, 4, 15, 48] [4, 15, 48]
45
+ 12091 12092 A047612 2 10789238 196 63 [0, 2, 4, 5, 8, 10, 12, 13, 16, 18, 20, 21] [120, 122, 124]
46
+ 12305 12306 A063144 1 8491582 189 49 [8, 27, 47, 67, 87, 107, 127, 147, 167, 187, 207, 227] [927, 947, 967]
47
+ 12646 12647 A129756 1 3311430 207 76 [1, 1, 1, 1, 3, 3, 3, 3, 5, 5, 5, 5] [37, 37, 37]
48
+ 13501 13502 A050621 1 20013751 188 18 [2, 12, 104, 1008, 10016, 100032, 1000064, 10000128, 100000256, 1000000512, 10000001024, 100000002048] [1000000000032768, 10000000000065536, 100000000000131072]
49
+ 13614 13615 A126048 3 11143698 95 48 [2, 3, 5, 7, 5, 1, 3, 7, 5, 1, 3, 7] [1, 1, 1]
50
+ 13703 13704 A163471 2 46480945 194 20 [3, 18, 114, 744, 4932, 32952, 221016, 1485216, 9989808, 67223328, 452457504, 3045661824] [283481998053888, 1908413999583744, 12847536038651904]
51
+ 13704 13705 A309224 3 10124493 99 35 [1, 1, 0, 0, 2, 1, 1, 2, 2, 3, 5, 7] [185, 202, 222]
52
+ 13737 13738 A101540 1 33989933 43 10 [4, 56, 500, 514, 626, 640, 724, 53110, 65330, 109672] [53110, 65330, 109672]
53
+ 14074 14075 A160124 1 19649603 206 57 [0, 0, 0, 2, 4, 4, 8, 18, 24, 24, 28, 36] [932, 1000, 1028]
54
+ 14081 14082 A010704 2 33797649 161 81 [3, 6, 3, 6, 3, 6, 3, 6, 3, 6, 3, 6] [3, 6, 3]
55
+ 14379 14380 A327424 2 605474 23 8 [1, 1, 2, 4, 10, 33, 234, 16579] [33, 234, 16579]
56
+ 14391 14392 A212611 2 5704190 167 24 [135, 3375, 25137, 68607, 22113, 557375, 451737, 1278316, 273, 6739509, 188325, 8775] [366776, 1487640245, 417725]
57
+ 14435 14436 A139302 5 8333360 100 4 [2, 512, 144115188075855872, 904625697166532776746648320380374280103671755200316906558262375061821325312] [512, 144115188075855872, 904625697166532776746648320380374280103671755200316906558262375061821325312]
58
+ 14960 14961 A057556 1 65334021 335 168 [0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0] [5, 0, 0]
59
+ 15368 15369 A056263 1 51823997 59 13 [1, 3, 27, 155, 321, 351, 1211, 1283, 7983, 15191, 84771, 119929] [84771, 119929, 148859]
60
+ 15953 15954 A007097 1 484792 261 24 [1, 2, 3, 5, 11, 31, 127, 709, 5381, 52711, 648391, 9737333] [9332039515881088707361, 499720579610303128776791, 28785866289100396890228041]
61
+ 16105 16106 A082705 4 11152753 23 6 [7, 9, 895, 1525, 3037, 21157] [1525, 3037, 21157]
62
+ 16931 16932 A255771 53 771418 21 11 [1, 1, 1, 2, 2, 1, 2, 2, 4, 2, 2] [4, 2, 2]
63
+ 17007 17008 A389124 2 37692541 230 29 [1, 3, 6, 17, 46, 133, 386, 1137, 3366, 10013, 29866, 89257] [1270955283786, 3812843481737, 11438485705966]
64
+ 17160 17161 A117995 1 52396131 207 51 [0, 0, 1, 1, 2, 3, 4, 6, 8, 11, 14, 20] [26252, 30700, 35717]
65
+ 17170 17171 A271449 3 5003778 13 6 [0, 3, 6, 9, 12, 18] [9, 12, 18]
66
+ 17255 17256 A176393 1 8944574 205 61 [1, 3, 9, 13, 17, 19, 21, 25, 29, 31, 33, 37] [161, 163, 165]
67
+ 17300 17301 A085070 11 9630624 8 4 [1, 2, 9, 26] [2, 9, 26]
68
+ 17347 17348 A295745 1 37424184 8 3 [18, 20, 24] [18, 20, 24]
69
+ 18185 18186 A380287 4 8734831 113 18 [4, 6, 16, 48, 142, 472, 1670, 6364, 24604, 97668, 390070, 1570560] [419930444, 1701635046, 6898183050]
70
+ 18227 18228 A133384 1 15280426 227 19 [12, 102, 1002, 10002, 100002, 1000002, 10000002, 100000002, 1000000002, 10000000002, 100000000002, 1000000000002] [100000000000000002, 1000000000000000002, 10000000000000000002]
71
+ 18258 18259 A011803 2 34513693 58 10 [1, 2, 8, 64, 904, 20926, 753994, 40412530, 3099627142, 329518779600] [40412530, 3099627142, 329518779600]
72
+ 18304 18305 A137444 2 13209959 217 41 [1, 4, 6, 4, -4, -16, -24, -16, 16, 64, 96, 64] [-1572864, -1048576, 1048576]
73
+ 18559 18560 A083318 1 45434387 198 32 [1, 3, 5, 9, 17, 33, 65, 129, 257, 513, 1025, 2049] [536870913, 1073741825, 2147483649]
74
+ 18605 18606 A141044 2 3030285 209 105 [2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] [1, 1, 1]