Text Generation
Transformers
Safetensors
MLX
qwen3
oeis
integer-sequences
causal-lm
text-generation-inference
Instructions to use N8Programs/NextTerm-440M with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use N8Programs/NextTerm-440M with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="N8Programs/NextTerm-440M")# Load model directly from transformers import AutoTokenizer, AutoModelForCausalLM tokenizer = AutoTokenizer.from_pretrained("N8Programs/NextTerm-440M") model = AutoModelForCausalLM.from_pretrained("N8Programs/NextTerm-440M") - MLX
How to use N8Programs/NextTerm-440M with MLX:
# Make sure mlx-lm is installed # pip install --upgrade mlx-lm # if on a CUDA device, also pip install mlx[cuda] # Generate text with mlx-lm from mlx_lm import load, generate model, tokenizer = load("N8Programs/NextTerm-440M") prompt = "Once upon a time in" text = generate(model, tokenizer, prompt=prompt, verbose=True) - Inference
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
- vLLM
How to use N8Programs/NextTerm-440M with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "N8Programs/NextTerm-440M" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "N8Programs/NextTerm-440M", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker
docker model run hf.co/N8Programs/NextTerm-440M
- SGLang
How to use N8Programs/NextTerm-440M with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "N8Programs/NextTerm-440M" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "N8Programs/NextTerm-440M", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "N8Programs/NextTerm-440M" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "N8Programs/NextTerm-440M", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }' - MLX LM
How to use N8Programs/NextTerm-440M with MLX LM:
Generate or start a chat session
# Install MLX LM uv tool install mlx-lm # Generate some text mlx_lm.generate --model "N8Programs/NextTerm-440M" --prompt "Once upon a time"
- Docker Model Runner
How to use N8Programs/NextTerm-440M with Docker Model Runner:
docker model run hf.co/N8Programs/NextTerm-440M
Add model card and evaluation utilities
Browse files- .gitattributes +1 -0
- README.md +176 -0
- arithmetic_eval.py +263 -0
- decode_packed_oeis.py +185 -0
- eval_m1_competition_mape_mlx.py +425 -0
- eval_results.txt +42 -0
- oeis_eval_mlx_neo.py +310 -0
- radar_chart.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
radar_chart.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,3 +1,179 @@
|
|
| 1 |
---
|
| 2 |
license: cc-by-sa-4.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
license: cc-by-sa-4.0
|
| 3 |
+
library_name: transformers
|
| 4 |
+
pipeline_tag: text-generation
|
| 5 |
+
tags:
|
| 6 |
+
- oeis
|
| 7 |
+
- integer-sequences
|
| 8 |
+
- qwen3
|
| 9 |
+
- causal-lm
|
| 10 |
+
- mlx
|
| 11 |
+
datasets:
|
| 12 |
+
- N8Programs/oeis-massive
|
| 13 |
---
|
| 14 |
+
|
| 15 |
+
# NextTerm-440M
|
| 16 |
+
|
| 17 |
+
[](radar_chart.png)
|
| 18 |
+
|
| 19 |
+
## Model Summary
|
| 20 |
+
|
| 21 |
+
NextTerm-440M is a 440M parameter causal transformer trained to continue integer
|
| 22 |
+
sequences. It uses a Qwen3 architecture with a compact 16-token digit
|
| 23 |
+
vocabulary: decimal digits, negative sign, comma separator, BOS, EOS, PAD, and one unused token.
|
| 24 |
+
|
| 25 |
+
The model was trained on an extended OEIS corpus that enhanced many OEIS sequences with additional terms from b-files (supplemental appendices provided with OEIS) and then further augmented the data w/ a variety of prefix-preserving transforms empirically selected via small pilot experiments. The model was trained for 14B tokens w/ preserved sequence prefixes rather than concatenating distinct documents (as this was found to improve performance in pilot experiments).
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
NextTerm-440M improves dramatically over NextTerm-47M on long-context sequence continuation (as it was trained w/ a context length of 4096), innate OEIS knowledge, and long-range in context learning. The 47M model, however, remains ahead on very short prefixes that require simple rule induction without much context, which may be due to the 440M model's training on longer contexts and more complex sequences.
|
| 29 |
+
|
| 30 |
+
The tokenizer accepts integer sequences formatted as comma-separated values, for
|
| 31 |
+
example:
|
| 32 |
+
|
| 33 |
+
```text
|
| 34 |
+
1,-2,3,-4,
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
The tokenizer ignores characters other than digits, commas, and `-`. Digits are
|
| 38 |
+
tokenized individually, so there is no fixed integer-magnitude limit, but large
|
| 39 |
+
integers consume more context. The model was not trained on numbers with leading
|
| 40 |
+
zeros, so strings like `01,02,03,` should be treated as out of distribution.
|
| 41 |
+
|
| 42 |
+
## Training Details
|
| 43 |
+
|
| 44 |
+
| Field | Value |
|
| 45 |
+
| --- | --- |
|
| 46 |
+
| Parameters | 440,500,224 |
|
| 47 |
+
| Architecture | Qwen3-style causal LM |
|
| 48 |
+
| Layers | 28 |
|
| 49 |
+
| Hidden size | 1024 |
|
| 50 |
+
| FFN size | 3072 |
|
| 51 |
+
| Attention heads | 16 |
|
| 52 |
+
| KV heads | 8 |
|
| 53 |
+
| Vocabulary size | 16 |
|
| 54 |
+
| Training tokens | 13,999,999,995 |
|
| 55 |
+
| Sequence length cap | 4096 training tokens per sequence |
|
| 56 |
+
| Batch mode | Length-bucketed sequence batches |
|
| 57 |
+
| Optimizer | Muon/AdamW hybrid |
|
| 58 |
+
| LR schedule | Linear warmup to `1e-2` for Muon, `1e-4` for AdamW, cosine decay to 0.1x, final cooldown to 0 |
|
| 59 |
+
| Training hardware | Single H100 |
|
| 60 |
+
| Export dtype | bfloat16 |
|
| 61 |
+
|
| 62 |
+
A classic Muon/AdamW hybrid was used: Muon for 2D weight matrices and AdamW for 1D parameters and embedding matrices.
|
| 63 |
+
|
| 64 |
+
The model was trained on the following files in the `N8Programs/oeis-massive` dataset, randomly mixed:
|
| 65 |
+
|
| 66 |
+
- `oeis_train_bfile_prefix4096.packed`
|
| 67 |
+
- `oeis_synth_aug0_inv_len_13245370099_seed0.packed`
|
| 68 |
+
|
| 69 |
+
## Evaluation Results
|
| 70 |
+
|
| 71 |
+
### Main Benchmarks
|
| 72 |
+
|
| 73 |
+
| Model | OEIS-Eval-Neo | Ryskina & Knight | M1 Competition 111 MAPE |
|
| 74 |
+
| --- | ---: | ---: | ---: |
|
| 75 |
+
| **NextTerm-440M** | **34.43%** | 52.63% | **17.6239** |
|
| 76 |
+
| NextTerm-47M | 29.49% | **70.18%** | 18.7621 |
|
| 77 |
+
| Qwen3-0.6B | 18.44% | 33.33% | 22.7984 |
|
| 78 |
+
| Qwen3-1.7B | 20.77% | 49.12% | 22.2411 |
|
| 79 |
+
| Qwen3-4B | 23.74% | 63.16% | 19.1731 |
|
| 80 |
+
| Qwen3-8B | 24.62% | 57.89% | 18.4027 |
|
| 81 |
+
| Qwen3-14B | 26.00% | 59.65% | 17.9837 |
|
| 82 |
+
|
| 83 |
+
OEIS-Eval-Neo is a decontaminated held-out OEIS next-term evaluation. M1
|
| 84 |
+
Competition 111 reports macro MAPE, where lower is better. Ryskina & Knight
|
| 85 |
+
(2021) is a 57-sequence next-term benchmark based on psychometrics and puzzles. Note that the 47M model's strong performance on Ryskina & Knight is indicative of its strength on short-prefix sequences and rule induction.
|
| 86 |
+
|
| 87 |
+
### Polynomial Continuation
|
| 88 |
+
|
| 89 |
+
The polynomial continuation evaluation samples integer sequences from
|
| 90 |
+
polynomials of degree 1 through 4 and asks for the next term. Accuracy is exact
|
| 91 |
+
match across 200 samples for each prompt length `k`.
|
| 92 |
+
|
| 93 |
+
| Model | Arithmetic | Quadratic | Cubic | Quartic |
|
| 94 |
+
| --- | ---: | ---: | ---: | ---: |
|
| 95 |
+
| **NextTerm-440M** | 94.38% | **86.39%** | **75.20%** | **67.83%** |
|
| 96 |
+
| NextTerm-47M | 94.15% | 81.07% | 37.43% | 15.17% |
|
| 97 |
+
| Qwen3-0.6B | 90.31% | 8.72% | 0.30% | 0.02% |
|
| 98 |
+
| Qwen3-1.7B | 93.10% | 41.57% | 5.36% | 0.71% |
|
| 99 |
+
| Qwen3-4B | 93.90% | 77.26% | 28.18% | 5.98% |
|
| 100 |
+
| Qwen3-8B | **96.10%** | 80.59% | 32.93% | 7.95% |
|
| 101 |
+
| Qwen3-14B | 95.60% | 84.61% | 49.16% | 14.98% |
|
| 102 |
+
|
| 103 |
+
## Usage
|
| 104 |
+
|
| 105 |
+
### MLX
|
| 106 |
+
|
| 107 |
+
```bash
|
| 108 |
+
mlx_lm.generate --model N8Programs/NextTerm-440M --prompt "1,2,3,"
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
### Hugging Face Transformers
|
| 112 |
+
|
| 113 |
+
```python
|
| 114 |
+
import torch
|
| 115 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 116 |
+
|
| 117 |
+
model_name = "N8Programs/NextTerm-440M"
|
| 118 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 119 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 120 |
+
model_name,
|
| 121 |
+
torch_dtype=torch.bfloat16,
|
| 122 |
+
device_map="auto",
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
prompt = "1,2,3,4,5,"
|
| 126 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 127 |
+
|
| 128 |
+
outputs = model.generate(
|
| 129 |
+
**inputs,
|
| 130 |
+
max_new_tokens=64,
|
| 131 |
+
do_sample=False,
|
| 132 |
+
eos_token_id=[tokenizer.convert_tokens_to_ids(","), tokenizer.eos_token_id],
|
| 133 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
For strict next-term evaluation, stop generation on *comma or EOS* and parse the
|
| 140 |
+
text before the first comma as the predicted integer.
|
| 141 |
+
|
| 142 |
+
## Reproducibility
|
| 143 |
+
|
| 144 |
+
This repository contains the local evaluation scripts and artifacts used for the
|
| 145 |
+
results above, including:
|
| 146 |
+
|
| 147 |
+
- `oeis_eval_mlx_neo.py` for OEIS-Eval-Neo with MLX batch generation.
|
| 148 |
+
- `arithmetic_eval.py` for arithmetic/quadratic/cubic/quartic continuation.
|
| 149 |
+
- `eval_m1_competition_mape_mlx.py` for M1 Competition 111 MAPE.
|
| 150 |
+
- `eval_results.txt` for the compact result table.
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
The last three training checkpoints are available separately at
|
| 154 |
+
[N8Programs/NextTerm-440M-Checkpoints](https://huggingface.co/N8Programs/NextTerm-440M-Checkpoints).
|
| 155 |
+
The released `final_latest` checkpoint was trained for 14B tokens. Additionally, the checkpoint corresponding to the best val loss is available as well (although it is not included in the main results table as it was inferior on downstream eval performance).
|
| 156 |
+
|
| 157 |
+
The `.packed` files used for training are binary files containing the tokenized and augmented OEIS data - w/ tokens encoded as nibbles. A dedicate decoder is provided in this repo as `decode_packed_oeis.py`.
|
| 158 |
+
|
| 159 |
+
## Citation
|
| 160 |
+
|
| 161 |
+
```bibtex
|
| 162 |
+
@misc{nextterm440m2026,
|
| 163 |
+
author = {Nathan Breslow},
|
| 164 |
+
title = {NextTerm-440M: A Pretrained Transformer for Integer Sequence Prediction},
|
| 165 |
+
year = {2026},
|
| 166 |
+
publisher = {Hugging Face},
|
| 167 |
+
howpublished = {\url{https://huggingface.co/N8Programs/NextTerm-440M}},
|
| 168 |
+
note = {440.5M parameter model trained on augmented OEIS data}
|
| 169 |
+
}
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
## Attribution
|
| 173 |
+
|
| 174 |
+
This model and dataset were trained and created using data from the
|
| 175 |
+
**On-Line Encyclopedia of Integer Sequences (OEIS)**.
|
| 176 |
+
|
| 177 |
+
- Source: https://oeis.org/
|
| 178 |
+
- License: Creative Commons Attribution-ShareAlike 4.0 (CC BY-SA 4.0)
|
| 179 |
+
- OEIS End-User License Agreement: https://oeis.org/wiki/The_OEIS_End-User_License_Agreement
|
arithmetic_eval.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluate the NextTerm model on synthetic polynomial sequences (arithmetic through
|
| 3 |
+
quartic) with varying amounts of prior context.
|
| 4 |
+
|
| 5 |
+
For each prior term count (default: 2-25 for linear, 3-25 for quadratic,
|
| 6 |
+
4-25 for cubic, 5-25 for quartic), we sample sequences from
|
| 7 |
+
ax^4 + bx^3 + cx^2 + dx + e with coefficient ranges:
|
| 8 |
+
e in [0, 99999]
|
| 9 |
+
d in [0, 9999]
|
| 10 |
+
c in [0, 999]
|
| 11 |
+
b in [0, 99]
|
| 12 |
+
a in [0, 9]
|
| 13 |
+
|
| 14 |
+
Higher-order coefficients are set to 0 for lower-degree tests (e.g., arithmetic
|
| 15 |
+
sequences only use d and e).
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import csv
|
| 20 |
+
import random
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
from typing import Dict, List, Tuple
|
| 23 |
+
|
| 24 |
+
from mlx_lm import load
|
| 25 |
+
from mlx_lm.generate import BatchGenerator
|
| 26 |
+
from mlx_lm.tuner.utils import print_trainable_parameters
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
|
| 29 |
+
# Step 335000
|
| 30 |
+
MODEL_NAME = "NextTerm-47M"
|
| 31 |
+
MAX_NEW_TOKENS = 20
|
| 32 |
+
DEFAULT_MAX_TERMS = 25
|
| 33 |
+
SAMPLES_PER_K = 200
|
| 34 |
+
RANDOM_SEED = 0
|
| 35 |
+
|
| 36 |
+
COEFF_MAX = {"a": 9, "b": 99, "c": 999, "d": 9999, "e": 99999}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass(frozen=True)
|
| 40 |
+
class PolynomialTask:
|
| 41 |
+
name: str
|
| 42 |
+
degree: int
|
| 43 |
+
min_terms: int
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
TASKS = (
|
| 47 |
+
PolynomialTask("arithmetic", 1, 2),
|
| 48 |
+
PolynomialTask("quadratic", 2, 3),
|
| 49 |
+
PolynomialTask("cubic", 3, 4),
|
| 50 |
+
PolynomialTask("quartic", 4, 5),
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def parse_generated(text: str):
|
| 55 |
+
if "," in text:
|
| 56 |
+
text = text.split(",")[0]
|
| 57 |
+
try:
|
| 58 |
+
return int(text)
|
| 59 |
+
except ValueError:
|
| 60 |
+
print(f"Could not parse generated text: {text!r}")
|
| 61 |
+
return None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def polynomial_value(coeffs: Tuple[int, int, int, int, int], x: int) -> int:
|
| 65 |
+
a, b, c, d, e = coeffs
|
| 66 |
+
return (
|
| 67 |
+
a * (x**4)
|
| 68 |
+
+ b * (x**3)
|
| 69 |
+
+ c * (x**2)
|
| 70 |
+
+ d * x
|
| 71 |
+
+ e
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def sample_coefficients(
|
| 76 |
+
degree: int, rng: random.Random
|
| 77 |
+
) -> Tuple[int, int, int, int, int]:
|
| 78 |
+
a = rng.randint(0, COEFF_MAX["a"]) if degree >= 4 else 0
|
| 79 |
+
b = rng.randint(0, COEFF_MAX["b"]) if degree >= 3 else 0
|
| 80 |
+
c = rng.randint(0, COEFF_MAX["c"]) if degree >= 2 else 0
|
| 81 |
+
d = rng.randint(0, COEFF_MAX["d"]) if degree >= 1 else 0
|
| 82 |
+
e = rng.randint(0, COEFF_MAX["e"])
|
| 83 |
+
return a, b, c, d, e
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def generate_polynomial_sequences(
|
| 87 |
+
task: PolynomialTask, max_terms: int, samples_per_k: int, rng: random.Random
|
| 88 |
+
) -> List[Dict[str, int | List[int] | str]]:
|
| 89 |
+
sequences = []
|
| 90 |
+
for k in range(task.min_terms, max_terms + 1):
|
| 91 |
+
for _ in range(samples_per_k):
|
| 92 |
+
coeffs = sample_coefficients(task.degree, rng)
|
| 93 |
+
seq = [polynomial_value(coeffs, i) for i in range(k + 1)]
|
| 94 |
+
sequences.append(
|
| 95 |
+
{"k": k, "prompt": seq[:-1], "answer": seq[-1], "task": task.name}
|
| 96 |
+
)
|
| 97 |
+
return sequences
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def build_prompts(
|
| 101 |
+
sequences: List[Dict[str, int | List[int] | str]], tokenizer
|
| 102 |
+
) -> Tuple[List[List[int]], List[Dict[str, int | List[int] | str]]]:
|
| 103 |
+
prompts = []
|
| 104 |
+
metadata = []
|
| 105 |
+
for record in sequences:
|
| 106 |
+
prompt_text = ",".join(str(x) for x in record["prompt"]) + ","
|
| 107 |
+
prompts.append(tokenizer.encode(prompt_text))
|
| 108 |
+
metadata.append(record)
|
| 109 |
+
return prompts, metadata
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def evaluate_task(
|
| 113 |
+
task: PolynomialTask,
|
| 114 |
+
args,
|
| 115 |
+
model,
|
| 116 |
+
tokenizer,
|
| 117 |
+
sep_token: int,
|
| 118 |
+
) -> Tuple[Dict[int, Dict[str, int]], List[Dict[str, object]]]:
|
| 119 |
+
rng = random.Random(args.seed + task.degree)
|
| 120 |
+
sequences = generate_polynomial_sequences(
|
| 121 |
+
task, args.max_terms, args.samples_per_k, rng
|
| 122 |
+
)
|
| 123 |
+
if not sequences:
|
| 124 |
+
print(
|
| 125 |
+
f"Skipping {task.name}: min_terms {task.min_terms} exceeds max_terms "
|
| 126 |
+
f"{args.max_terms}"
|
| 127 |
+
)
|
| 128 |
+
return {}, []
|
| 129 |
+
|
| 130 |
+
print(
|
| 131 |
+
f"\nEvaluating {task.name} sequences (degree {task.degree}) with "
|
| 132 |
+
f"{len(sequences)} samples ({args.samples_per_k} per prior-term count)"
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
prompts, metadata = build_prompts(sequences, tokenizer)
|
| 136 |
+
|
| 137 |
+
gen = BatchGenerator(model, stop_tokens=[[sep_token], [tokenizer.eos_token_id]])
|
| 138 |
+
uids = gen.insert(prompts, [args.max_new_tokens] * len(prompts))
|
| 139 |
+
|
| 140 |
+
meta_by_uid = dict(zip(uids, metadata))
|
| 141 |
+
results = {uid: [] for uid in uids}
|
| 142 |
+
finished = set()
|
| 143 |
+
with tqdm(total=len(uids), desc=f"{task.name.title()}") as pbar:
|
| 144 |
+
while True:
|
| 145 |
+
responses = gen.next()
|
| 146 |
+
if not isinstance(responses, tuple) or len(responses) != 2:
|
| 147 |
+
raise RuntimeError(
|
| 148 |
+
"Unexpected mlx_lm BatchGenerator.next() API. "
|
| 149 |
+
"Update your mlx_lm version"
|
| 150 |
+
)
|
| 151 |
+
prompt_responses, generation_responses = responses
|
| 152 |
+
if not prompt_responses and not generation_responses:
|
| 153 |
+
break
|
| 154 |
+
if not generation_responses:
|
| 155 |
+
continue
|
| 156 |
+
for r in generation_responses:
|
| 157 |
+
results[r.uid].append(r.token)
|
| 158 |
+
if r.finish_reason == "stop" and r.uid not in finished:
|
| 159 |
+
finished.add(r.uid)
|
| 160 |
+
pbar.update(1)
|
| 161 |
+
gen.close()
|
| 162 |
+
|
| 163 |
+
stats = {
|
| 164 |
+
k: {"correct": 0, "parsed": 0, "total": 0}
|
| 165 |
+
for k in range(task.min_terms, args.max_terms + 1)
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
for uid in uids:
|
| 169 |
+
tokens = results[uid]
|
| 170 |
+
text = tokenizer.decode(tokens[:-1]) if tokens else ""
|
| 171 |
+
prediction = parse_generated(text)
|
| 172 |
+
|
| 173 |
+
meta = meta_by_uid[uid]
|
| 174 |
+
k = meta["k"]
|
| 175 |
+
answer = meta["answer"]
|
| 176 |
+
|
| 177 |
+
stats[k]["total"] += 1
|
| 178 |
+
if prediction is not None:
|
| 179 |
+
stats[k]["parsed"] += 1
|
| 180 |
+
if prediction == answer:
|
| 181 |
+
stats[k]["correct"] += 1
|
| 182 |
+
|
| 183 |
+
print("Accuracy by prior term count:")
|
| 184 |
+
csv_rows: List[Dict[str, object]] = []
|
| 185 |
+
for k in range(task.min_terms, args.max_terms + 1):
|
| 186 |
+
s = stats[k]
|
| 187 |
+
accuracy = s["correct"] / s["total"] if s["total"] else 0.0
|
| 188 |
+
print(
|
| 189 |
+
f"{k} terms: parsed {s['parsed']}/{s['total']} "
|
| 190 |
+
f"accuracy {s['correct']}/{s['total']} = {accuracy:.4f}"
|
| 191 |
+
)
|
| 192 |
+
csv_rows.append(
|
| 193 |
+
{
|
| 194 |
+
"task": task.name,
|
| 195 |
+
"degree": task.degree,
|
| 196 |
+
"k": k,
|
| 197 |
+
"parsed": s["parsed"],
|
| 198 |
+
"total": s["total"],
|
| 199 |
+
"correct": s["correct"],
|
| 200 |
+
"accuracy": accuracy,
|
| 201 |
+
}
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
overall_total = sum(s["total"] for s in stats.values())
|
| 205 |
+
overall_parsed = sum(s["parsed"] for s in stats.values())
|
| 206 |
+
overall_correct = sum(s["correct"] for s in stats.values())
|
| 207 |
+
overall_accuracy = overall_correct / overall_total if overall_total else 0.0
|
| 208 |
+
print(
|
| 209 |
+
f"Overall {task.name}: parsed {overall_parsed}/{overall_total} "
|
| 210 |
+
f"accuracy {overall_correct}/{overall_total} = {overall_accuracy:.4f}"
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
return stats, csv_rows
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def write_csv(rows: List[Dict[str, object]], path: str) -> None:
|
| 217 |
+
if not rows:
|
| 218 |
+
print("No results to write to CSV.")
|
| 219 |
+
return
|
| 220 |
+
|
| 221 |
+
fieldnames = ["task", "degree", "k", "parsed", "total", "correct", "accuracy"]
|
| 222 |
+
with open(path, "w", newline="") as f:
|
| 223 |
+
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
| 224 |
+
writer.writeheader()
|
| 225 |
+
writer.writerows(rows)
|
| 226 |
+
print(f"Wrote results to {path}")
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def evaluate(args):
|
| 230 |
+
random.seed(args.seed)
|
| 231 |
+
|
| 232 |
+
model, tokenizer = load(args.model_name)
|
| 233 |
+
|
| 234 |
+
print_trainable_parameters(model)
|
| 235 |
+
|
| 236 |
+
sep_token = tokenizer.encode("1,")[-1]
|
| 237 |
+
all_rows: List[Dict[str, object]] = []
|
| 238 |
+
for task in TASKS:
|
| 239 |
+
_, rows = evaluate_task(task, args, model, tokenizer, sep_token)
|
| 240 |
+
all_rows.extend(rows)
|
| 241 |
+
|
| 242 |
+
if args.csv_path:
|
| 243 |
+
write_csv(all_rows, args.csv_path)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def parse_args():
|
| 247 |
+
parser = argparse.ArgumentParser()
|
| 248 |
+
parser.add_argument("--model_name", default=MODEL_NAME)
|
| 249 |
+
parser.add_argument("--max_new_tokens", type=int, default=MAX_NEW_TOKENS)
|
| 250 |
+
parser.add_argument("--max_terms", type=int, default=DEFAULT_MAX_TERMS)
|
| 251 |
+
parser.add_argument("--samples_per_k", type=int, default=SAMPLES_PER_K)
|
| 252 |
+
parser.add_argument("--seed", type=int, default=RANDOM_SEED)
|
| 253 |
+
parser.add_argument("--csv_path", type=str, default=None)
|
| 254 |
+
return parser.parse_args()
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def main():
|
| 258 |
+
args = parse_args()
|
| 259 |
+
evaluate(args)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
if __name__ == "__main__":
|
| 263 |
+
main()
|
decode_packed_oeis.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Small reference decoder for bigOEIS `.packed` files.
|
| 2 |
+
|
| 3 |
+
On disk, each token is stored in a 4-bit nibble:
|
| 4 |
+
|
| 5 |
+
0..9 decimal digits
|
| 6 |
+
10 term separator, i.e. comma
|
| 7 |
+
11 negative sign
|
| 8 |
+
14 final padding nibble, if the file has an odd nibble count
|
| 9 |
+
15 sequence delimiter / EOS
|
| 10 |
+
|
| 11 |
+
Note that the packed disk codes are intentionally compact and are not exactly
|
| 12 |
+
the model vocabulary ids: the model uses NEG=10 and SEP=11. Use
|
| 13 |
+
`iter_model_token_rows()` when you want rows in model-token-id space.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import json
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Iterator
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
PACKED_SEP = 10
|
| 25 |
+
PACKED_NEG = 11
|
| 26 |
+
PACKED_PAD = 14
|
| 27 |
+
PACKED_DELIM = 15
|
| 28 |
+
|
| 29 |
+
MODEL_NEG = 10
|
| 30 |
+
MODEL_SEP = 11
|
| 31 |
+
MODEL_BOS = 12
|
| 32 |
+
MODEL_EOS = 13
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def iter_packed_nibbles(path: str | Path, chunk_size: int = 8 * 1024 * 1024) -> Iterator[int]:
|
| 36 |
+
"""Yield high nibble then low nibble for every byte in `path`."""
|
| 37 |
+
with Path(path).open("rb") as f:
|
| 38 |
+
while True:
|
| 39 |
+
chunk = f.read(chunk_size)
|
| 40 |
+
if not chunk:
|
| 41 |
+
return
|
| 42 |
+
for byte in chunk:
|
| 43 |
+
yield byte >> 4
|
| 44 |
+
yield byte & 0x0F
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def iter_model_token_rows(
|
| 48 |
+
path: str | Path,
|
| 49 |
+
*,
|
| 50 |
+
include_bos_eos: bool = True,
|
| 51 |
+
max_rows: int | None = None,
|
| 52 |
+
strict: bool = True,
|
| 53 |
+
) -> Iterator[list[int]]:
|
| 54 |
+
"""Yield each packed sequence as model-token ids.
|
| 55 |
+
|
| 56 |
+
By default rows include BOS/EOS, matching `tokenize_utils.tokenize_sequence`.
|
| 57 |
+
Set `include_bos_eos=False` to get only the content tokens.
|
| 58 |
+
"""
|
| 59 |
+
row: list[int] = [MODEL_BOS] if include_bos_eos else []
|
| 60 |
+
yielded = 0
|
| 61 |
+
seen_pad = False
|
| 62 |
+
|
| 63 |
+
for nib in iter_packed_nibbles(path):
|
| 64 |
+
if nib == PACKED_PAD:
|
| 65 |
+
seen_pad = True
|
| 66 |
+
continue
|
| 67 |
+
if seen_pad:
|
| 68 |
+
if strict:
|
| 69 |
+
raise ValueError("Found non-pad nibble after final packed padding.")
|
| 70 |
+
continue
|
| 71 |
+
|
| 72 |
+
if 0 <= nib <= 9:
|
| 73 |
+
row.append(nib)
|
| 74 |
+
elif nib == PACKED_SEP:
|
| 75 |
+
row.append(MODEL_SEP)
|
| 76 |
+
elif nib == PACKED_NEG:
|
| 77 |
+
row.append(MODEL_NEG)
|
| 78 |
+
elif nib == PACKED_DELIM:
|
| 79 |
+
if include_bos_eos:
|
| 80 |
+
row.append(MODEL_EOS)
|
| 81 |
+
yield row
|
| 82 |
+
yielded += 1
|
| 83 |
+
if max_rows is not None and yielded >= max_rows:
|
| 84 |
+
return
|
| 85 |
+
row = [MODEL_BOS] if include_bos_eos else []
|
| 86 |
+
else:
|
| 87 |
+
if strict:
|
| 88 |
+
raise ValueError(f"Invalid packed nibble: {nib}")
|
| 89 |
+
|
| 90 |
+
empty = [MODEL_BOS] if include_bos_eos else []
|
| 91 |
+
if strict and row != empty:
|
| 92 |
+
raise ValueError("Packed file ended with an unterminated sequence.")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def iter_integer_sequences(
|
| 96 |
+
path: str | Path,
|
| 97 |
+
*,
|
| 98 |
+
as_ints: bool = False,
|
| 99 |
+
max_rows: int | None = None,
|
| 100 |
+
strict: bool = True,
|
| 101 |
+
) -> Iterator[list[int] | list[str]]:
|
| 102 |
+
"""Yield decoded OEIS rows.
|
| 103 |
+
|
| 104 |
+
Values are strings by default so enormous integers round-trip exactly
|
| 105 |
+
through JSON. Pass `as_ints=True` if Python integers are more convenient.
|
| 106 |
+
"""
|
| 107 |
+
terms: list[str] = []
|
| 108 |
+
chars: list[str] = []
|
| 109 |
+
yielded = 0
|
| 110 |
+
seen_pad = False
|
| 111 |
+
|
| 112 |
+
def finish_term() -> None:
|
| 113 |
+
if chars:
|
| 114 |
+
terms.append("".join(chars))
|
| 115 |
+
chars.clear()
|
| 116 |
+
elif strict:
|
| 117 |
+
raise ValueError("Encountered an empty term in packed sequence.")
|
| 118 |
+
|
| 119 |
+
for nib in iter_packed_nibbles(path):
|
| 120 |
+
if nib == PACKED_PAD:
|
| 121 |
+
seen_pad = True
|
| 122 |
+
continue
|
| 123 |
+
if seen_pad:
|
| 124 |
+
if strict:
|
| 125 |
+
raise ValueError("Found non-pad nibble after final packed padding.")
|
| 126 |
+
continue
|
| 127 |
+
|
| 128 |
+
if 0 <= nib <= 9:
|
| 129 |
+
chars.append(str(nib))
|
| 130 |
+
elif nib == PACKED_NEG:
|
| 131 |
+
if strict and chars:
|
| 132 |
+
raise ValueError("Found a negative sign after term digits had started.")
|
| 133 |
+
chars.append("-")
|
| 134 |
+
elif nib == PACKED_SEP:
|
| 135 |
+
finish_term()
|
| 136 |
+
elif nib == PACKED_DELIM:
|
| 137 |
+
if chars:
|
| 138 |
+
finish_term()
|
| 139 |
+
row = [int(term) for term in terms] if as_ints else list(terms)
|
| 140 |
+
yield row
|
| 141 |
+
yielded += 1
|
| 142 |
+
if max_rows is not None and yielded >= max_rows:
|
| 143 |
+
return
|
| 144 |
+
terms.clear()
|
| 145 |
+
else:
|
| 146 |
+
if strict:
|
| 147 |
+
raise ValueError(f"Invalid packed nibble: {nib}")
|
| 148 |
+
|
| 149 |
+
if strict and (terms or chars):
|
| 150 |
+
raise ValueError("Packed file ended with an unterminated sequence.")
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _main() -> None:
|
| 154 |
+
parser = argparse.ArgumentParser(description=__doc__)
|
| 155 |
+
parser.add_argument("packed_file", help="Path to a .packed file")
|
| 156 |
+
parser.add_argument("-n", "--max-rows", type=int, default=5)
|
| 157 |
+
parser.add_argument("--tokens", action="store_true", help="Print model-token rows instead of decoded terms")
|
| 158 |
+
parser.add_argument("--content-only", action="store_true", help="Omit BOS/EOS when printing token rows")
|
| 159 |
+
parser.add_argument("--ints", action="store_true", help="Emit decoded terms as JSON numbers instead of strings")
|
| 160 |
+
parser.add_argument("--no-strict", action="store_true", help="Ignore invalid/trailing data instead of raising")
|
| 161 |
+
args = parser.parse_args()
|
| 162 |
+
|
| 163 |
+
strict = not args.no_strict
|
| 164 |
+
if args.tokens:
|
| 165 |
+
rows = iter_model_token_rows(
|
| 166 |
+
args.packed_file,
|
| 167 |
+
include_bos_eos=not args.content_only,
|
| 168 |
+
max_rows=args.max_rows,
|
| 169 |
+
strict=strict,
|
| 170 |
+
)
|
| 171 |
+
for row in rows:
|
| 172 |
+
print(json.dumps({"tokens": row}, separators=(",", ":")))
|
| 173 |
+
else:
|
| 174 |
+
rows = iter_integer_sequences(
|
| 175 |
+
args.packed_file,
|
| 176 |
+
as_ints=args.ints,
|
| 177 |
+
max_rows=args.max_rows,
|
| 178 |
+
strict=strict,
|
| 179 |
+
)
|
| 180 |
+
for row in rows:
|
| 181 |
+
print(json.dumps({"seq": row}, separators=(",", ":")))
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
if __name__ == "__main__":
|
| 185 |
+
_main()
|
eval_m1_competition_mape_mlx.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Evaluate NextTerm-style MLX models on the canonical 111-series M1 set."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import gc
|
| 8 |
+
import json
|
| 9 |
+
import math
|
| 10 |
+
import time
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from decimal import Decimal
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from statistics import mean, median
|
| 15 |
+
|
| 16 |
+
import mlx.core as mx
|
| 17 |
+
from mlx_lm import load
|
| 18 |
+
from mlx_lm.generate import BatchGenerator
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
|
| 21 |
+
from eval_m1_monthly_mape_mlx import (
|
| 22 |
+
SuppressTokenLogits,
|
| 23 |
+
context_scale,
|
| 24 |
+
mape_for_predictions,
|
| 25 |
+
parse_decimal,
|
| 26 |
+
parse_generated_terms,
|
| 27 |
+
scale_decimal_to_int,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
DATA_PATH = Path("m1_competition_111.jsonl")
|
| 32 |
+
MODEL_PATH = Path("NextTerm-440M")
|
| 33 |
+
OUTPUT_PATH = Path("m1_eval_results/m1_competition111_nextterm440m_greedy_per_series.jsonl")
|
| 34 |
+
SUMMARY_PATH = Path("m1_eval_results/m1_competition111_nextterm440m_greedy_summary.json")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class SeriesRecord:
|
| 39 |
+
row_index: int
|
| 40 |
+
series_name: str
|
| 41 |
+
original_id: str
|
| 42 |
+
period: str
|
| 43 |
+
frequency: int
|
| 44 |
+
type: str
|
| 45 |
+
description: str
|
| 46 |
+
horizon: int
|
| 47 |
+
raw_context: list[str]
|
| 48 |
+
raw_target: list[str]
|
| 49 |
+
context_values: list[Decimal]
|
| 50 |
+
target_values: list[Decimal]
|
| 51 |
+
naive2_forecast: list[Decimal]
|
| 52 |
+
naive2_mape: float
|
| 53 |
+
scale: int
|
| 54 |
+
scaled_context: list[int]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def load_m1_competition(path: Path) -> list[SeriesRecord]:
|
| 58 |
+
records: list[SeriesRecord] = []
|
| 59 |
+
with path.open("r", encoding="utf-8") as f:
|
| 60 |
+
for line in f:
|
| 61 |
+
if not line.strip():
|
| 62 |
+
continue
|
| 63 |
+
row = json.loads(line)
|
| 64 |
+
raw_context = [str(x) for x in row["context"]]
|
| 65 |
+
raw_target = [str(x) for x in row["target"]]
|
| 66 |
+
context_values = [parse_decimal(x) for x in raw_context]
|
| 67 |
+
target_values = [parse_decimal(x) for x in raw_target]
|
| 68 |
+
scale = context_scale(raw_context)
|
| 69 |
+
scaled_context = [scale_decimal_to_int(v, scale) for v in context_values]
|
| 70 |
+
records.append(
|
| 71 |
+
SeriesRecord(
|
| 72 |
+
row_index=int(row["row_index"]),
|
| 73 |
+
series_name=str(row["series_name"]),
|
| 74 |
+
original_id=str(row["original_id"]),
|
| 75 |
+
period=str(row["period"]),
|
| 76 |
+
frequency=int(row["frequency"]),
|
| 77 |
+
type=str(row["type"]),
|
| 78 |
+
description=str(row["description"]),
|
| 79 |
+
horizon=int(row["horizon"]),
|
| 80 |
+
raw_context=raw_context,
|
| 81 |
+
raw_target=raw_target,
|
| 82 |
+
context_values=context_values,
|
| 83 |
+
target_values=target_values,
|
| 84 |
+
naive2_forecast=[parse_decimal(str(x)) for x in row["naive2_forecast"]],
|
| 85 |
+
naive2_mape=float(row["naive2_mape"]),
|
| 86 |
+
scale=scale,
|
| 87 |
+
scaled_context=scaled_context,
|
| 88 |
+
)
|
| 89 |
+
)
|
| 90 |
+
return records
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def decimal_or_none(value: Decimal | None) -> str | None:
|
| 94 |
+
return str(value) if value is not None else None
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def load_completed(path: Path) -> dict[int, dict]:
|
| 98 |
+
completed: dict[int, dict] = {}
|
| 99 |
+
if not path.exists():
|
| 100 |
+
return completed
|
| 101 |
+
with path.open("r", encoding="utf-8") as f:
|
| 102 |
+
for line in f:
|
| 103 |
+
if not line.strip():
|
| 104 |
+
continue
|
| 105 |
+
record = json.loads(line)
|
| 106 |
+
completed[int(record["row_index"])] = record
|
| 107 |
+
return completed
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def aggregate_summary(records: list[dict]) -> dict:
|
| 111 |
+
def collect(key: str, rows: list[dict] = records) -> list[float]:
|
| 112 |
+
vals: list[float] = []
|
| 113 |
+
for r in rows:
|
| 114 |
+
value = r.get(key)
|
| 115 |
+
if value is not None and math.isfinite(float(value)):
|
| 116 |
+
vals.append(float(value))
|
| 117 |
+
return vals
|
| 118 |
+
|
| 119 |
+
model_series = collect("mape")
|
| 120 |
+
naive2_series = collect("naive2_mape")
|
| 121 |
+
all_model_apes = [
|
| 122 |
+
float(x)
|
| 123 |
+
for r in records
|
| 124 |
+
for x in (r.get("apes") or [])
|
| 125 |
+
if x is not None and math.isfinite(float(x))
|
| 126 |
+
]
|
| 127 |
+
all_naive2_apes = [
|
| 128 |
+
float(x)
|
| 129 |
+
for r in records
|
| 130 |
+
for x in (r.get("naive2_apes") or [])
|
| 131 |
+
if x is not None and math.isfinite(float(x))
|
| 132 |
+
]
|
| 133 |
+
|
| 134 |
+
periods = sorted({str(r["period"]) for r in records})
|
| 135 |
+
by_period = {}
|
| 136 |
+
for period in periods:
|
| 137 |
+
rows = [r for r in records if r["period"] == period]
|
| 138 |
+
period_model = collect("mape", rows)
|
| 139 |
+
period_naive2 = collect("naive2_mape", rows)
|
| 140 |
+
period_model_apes = [
|
| 141 |
+
float(x)
|
| 142 |
+
for r in rows
|
| 143 |
+
for x in (r.get("apes") or [])
|
| 144 |
+
if x is not None and math.isfinite(float(x))
|
| 145 |
+
]
|
| 146 |
+
period_naive2_apes = [
|
| 147 |
+
float(x)
|
| 148 |
+
for r in rows
|
| 149 |
+
for x in (r.get("naive2_apes") or [])
|
| 150 |
+
if x is not None and math.isfinite(float(x))
|
| 151 |
+
]
|
| 152 |
+
by_period[period] = {
|
| 153 |
+
"series_count": len(rows),
|
| 154 |
+
"model_macro_mape": mean(period_model) if period_model else None,
|
| 155 |
+
"model_point_mape": mean(period_model_apes) if period_model_apes else None,
|
| 156 |
+
"naive2_macro_mape": mean(period_naive2) if period_naive2 else None,
|
| 157 |
+
"naive2_point_mape": mean(period_naive2_apes) if period_naive2_apes else None,
|
| 158 |
+
"parsed_full_series": sum(
|
| 159 |
+
1 for r in rows if int(r.get("parsed_terms", 0)) >= int(r.get("horizon", 0))
|
| 160 |
+
),
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
horizon_counts: dict[int, int] = {}
|
| 164 |
+
for r in records:
|
| 165 |
+
h = int(r["horizon"])
|
| 166 |
+
horizon_counts[h] = horizon_counts.get(h, 0) + 1
|
| 167 |
+
|
| 168 |
+
return {
|
| 169 |
+
"series_count": len(records),
|
| 170 |
+
"horizon_counts": dict(sorted(horizon_counts.items())),
|
| 171 |
+
"model_macro_mape": mean(model_series) if model_series else None,
|
| 172 |
+
"model_median_series_mape": median(model_series) if model_series else None,
|
| 173 |
+
"model_point_mape": mean(all_model_apes) if all_model_apes else None,
|
| 174 |
+
"naive2_macro_mape": mean(naive2_series) if naive2_series else None,
|
| 175 |
+
"naive2_point_mape": mean(all_naive2_apes) if all_naive2_apes else None,
|
| 176 |
+
"parsed_full_series": sum(
|
| 177 |
+
1 for r in records if int(r.get("parsed_terms", 0)) >= int(r.get("horizon", 0))
|
| 178 |
+
),
|
| 179 |
+
"total_missing_terms": sum(int(r.get("missing_terms", 0)) for r in records),
|
| 180 |
+
"malformed_series": sum(1 for r in records if r.get("malformed", False)),
|
| 181 |
+
"by_period": by_period,
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def parse_args() -> argparse.Namespace:
|
| 186 |
+
parser = argparse.ArgumentParser()
|
| 187 |
+
parser.add_argument("--data-path", type=Path, default=DATA_PATH)
|
| 188 |
+
parser.add_argument("--model", type=Path, default=MODEL_PATH)
|
| 189 |
+
parser.add_argument("--output", type=Path, default=OUTPUT_PATH)
|
| 190 |
+
parser.add_argument("--summary-output", type=Path, default=SUMMARY_PATH)
|
| 191 |
+
parser.add_argument("--batch-size", type=int, default=64)
|
| 192 |
+
parser.add_argument("--max-new-tokens", type=int, default=1024)
|
| 193 |
+
parser.add_argument("--max-context-tokens", type=int, default=2048)
|
| 194 |
+
parser.add_argument("--max-series", type=int, default=0)
|
| 195 |
+
parser.add_argument(
|
| 196 |
+
"--allow-eos-before-horizon",
|
| 197 |
+
action="store_true",
|
| 198 |
+
help="Do not suppress EOS/pad while waiting for all forecast terms.",
|
| 199 |
+
)
|
| 200 |
+
parser.add_argument(
|
| 201 |
+
"--missing-policy",
|
| 202 |
+
choices=["zero", "last_context", "skip"],
|
| 203 |
+
default="zero",
|
| 204 |
+
)
|
| 205 |
+
parser.add_argument("--overwrite", action="store_true")
|
| 206 |
+
return parser.parse_args()
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def main() -> None:
|
| 210 |
+
args = parse_args()
|
| 211 |
+
started = time.perf_counter()
|
| 212 |
+
series = load_m1_competition(args.data_path)
|
| 213 |
+
if args.max_series > 0:
|
| 214 |
+
series = series[: args.max_series]
|
| 215 |
+
print(f"Loaded {len(series)} canonical M1 competition series from {args.data_path}")
|
| 216 |
+
print("Horizons:", {h: sum(1 for s in series if s.horizon == h) for h in sorted({s.horizon for s in series})})
|
| 217 |
+
|
| 218 |
+
model, tokenizer = load(str(args.model))
|
| 219 |
+
print(f"Loaded model: {args.model}")
|
| 220 |
+
|
| 221 |
+
prompts_text = [",".join(str(x) for x in s.scaled_context) + "," for s in series]
|
| 222 |
+
prompts = [tokenizer.encode(text) for text in prompts_text]
|
| 223 |
+
sep_token = tokenizer.encode("1,")[-1]
|
| 224 |
+
eval_indices = [
|
| 225 |
+
i
|
| 226 |
+
for i, prompt in enumerate(prompts)
|
| 227 |
+
if args.max_context_tokens <= 0 or len(prompt) < args.max_context_tokens
|
| 228 |
+
]
|
| 229 |
+
skipped_long = len(prompts) - len(eval_indices)
|
| 230 |
+
eval_indices = sorted(eval_indices, key=lambda i: len(prompts[i]))
|
| 231 |
+
print(
|
| 232 |
+
f"Evaluating {len(eval_indices)} series; skipped_long={skipped_long}; "
|
| 233 |
+
f"batch_size={args.batch_size}; max_new_tokens={args.max_new_tokens}; sep_token={sep_token}"
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
args.output.parent.mkdir(parents=True, exist_ok=True)
|
| 237 |
+
args.summary_output.parent.mkdir(parents=True, exist_ok=True)
|
| 238 |
+
if args.overwrite and args.output.exists():
|
| 239 |
+
args.output.unlink()
|
| 240 |
+
completed = load_completed(args.output)
|
| 241 |
+
if completed:
|
| 242 |
+
print(f"Resuming from {args.output}: {len(completed)} series already done")
|
| 243 |
+
todo_indices = [i for i in eval_indices if i not in completed]
|
| 244 |
+
|
| 245 |
+
stop_tokens = []
|
| 246 |
+
suppress_token_ids: list[int] = []
|
| 247 |
+
if getattr(tokenizer, "eos_token_id", None) is not None:
|
| 248 |
+
suppress_token_ids.append(tokenizer.eos_token_id)
|
| 249 |
+
if args.allow_eos_before_horizon:
|
| 250 |
+
stop_tokens.append([tokenizer.eos_token_id])
|
| 251 |
+
if getattr(tokenizer, "pad_token_id", None) is not None:
|
| 252 |
+
suppress_token_ids.append(tokenizer.pad_token_id)
|
| 253 |
+
if args.allow_eos_before_horizon:
|
| 254 |
+
stop_tokens.append([tokenizer.pad_token_id])
|
| 255 |
+
suppress_processor = None if args.allow_eos_before_horizon else SuppressTokenLogits(suppress_token_ids)
|
| 256 |
+
gen = BatchGenerator(
|
| 257 |
+
model,
|
| 258 |
+
stop_tokens=stop_tokens or None,
|
| 259 |
+
completion_batch_size=args.batch_size,
|
| 260 |
+
prefill_batch_size=args.batch_size,
|
| 261 |
+
)
|
| 262 |
+
uid_to_idx: dict[int, int] = {}
|
| 263 |
+
generated_tokens: dict[int, list[int]] = {}
|
| 264 |
+
separator_counts: dict[int, int] = {}
|
| 265 |
+
if todo_indices:
|
| 266 |
+
logits_processors = (
|
| 267 |
+
[[suppress_processor] for _ in todo_indices]
|
| 268 |
+
if suppress_processor is not None
|
| 269 |
+
else None
|
| 270 |
+
)
|
| 271 |
+
uids = gen.insert(
|
| 272 |
+
[prompts[i] for i in todo_indices],
|
| 273 |
+
[args.max_new_tokens] * len(todo_indices),
|
| 274 |
+
logits_processors=logits_processors,
|
| 275 |
+
)
|
| 276 |
+
uid_to_idx = {uid: idx for uid, idx in zip(uids, todo_indices)}
|
| 277 |
+
generated_tokens = {uid: [] for uid in uids}
|
| 278 |
+
separator_counts = {uid: 0 for uid in uids}
|
| 279 |
+
finished: set[int] = set()
|
| 280 |
+
|
| 281 |
+
try:
|
| 282 |
+
with args.output.open("a", encoding="utf-8") as out:
|
| 283 |
+
with tqdm(total=len(todo_indices), desc="Generating") as progress:
|
| 284 |
+
|
| 285 |
+
def finalize_uid(uid: int, finish_reason: str) -> None:
|
| 286 |
+
if uid in finished:
|
| 287 |
+
return
|
| 288 |
+
finished.add(uid)
|
| 289 |
+
idx = uid_to_idx[uid]
|
| 290 |
+
s = series[idx]
|
| 291 |
+
generated_text = tokenizer.decode(generated_tokens[uid])
|
| 292 |
+
scaled_terms, malformed = parse_generated_terms(generated_text, s.horizon)
|
| 293 |
+
predictions: list[Decimal | None] = [
|
| 294 |
+
Decimal(term) / Decimal(s.scale) for term in scaled_terms
|
| 295 |
+
]
|
| 296 |
+
predictions.extend([None] * (s.horizon - len(predictions)))
|
| 297 |
+
fallback = s.context_values[-1] if s.context_values else Decimal(0)
|
| 298 |
+
mape, apes, missing = mape_for_predictions(
|
| 299 |
+
s.target_values,
|
| 300 |
+
predictions,
|
| 301 |
+
missing_policy=args.missing_policy,
|
| 302 |
+
fallback=fallback,
|
| 303 |
+
)
|
| 304 |
+
naive2_mape, naive2_apes, _ = mape_for_predictions(
|
| 305 |
+
s.target_values,
|
| 306 |
+
s.naive2_forecast,
|
| 307 |
+
missing_policy="skip",
|
| 308 |
+
fallback=fallback,
|
| 309 |
+
)
|
| 310 |
+
record = {
|
| 311 |
+
"row_index": s.row_index,
|
| 312 |
+
"series_name": s.series_name,
|
| 313 |
+
"original_id": s.original_id,
|
| 314 |
+
"period": s.period,
|
| 315 |
+
"frequency": s.frequency,
|
| 316 |
+
"type": s.type,
|
| 317 |
+
"description": s.description,
|
| 318 |
+
"horizon": s.horizon,
|
| 319 |
+
"scale": s.scale,
|
| 320 |
+
"context_length": len(s.context_values),
|
| 321 |
+
"prompt_tokens": len(prompts[idx]),
|
| 322 |
+
"target": [str(x) for x in s.target_values],
|
| 323 |
+
"scaled_prediction_terms": scaled_terms,
|
| 324 |
+
"prediction": [decimal_or_none(x) for x in predictions],
|
| 325 |
+
"parsed_terms": len(scaled_terms),
|
| 326 |
+
"missing_terms": missing,
|
| 327 |
+
"malformed": malformed,
|
| 328 |
+
"mape": mape,
|
| 329 |
+
"apes": apes,
|
| 330 |
+
"naive2_forecast": [str(x) for x in s.naive2_forecast],
|
| 331 |
+
"naive2_mape": naive2_mape,
|
| 332 |
+
"naive2_apes": naive2_apes,
|
| 333 |
+
"generated_text": generated_text,
|
| 334 |
+
"finish_reason": finish_reason,
|
| 335 |
+
}
|
| 336 |
+
out.write(json.dumps(record) + "\n")
|
| 337 |
+
out.flush()
|
| 338 |
+
progress.update(1)
|
| 339 |
+
|
| 340 |
+
while todo_indices and len(finished) < len(todo_indices):
|
| 341 |
+
responses = gen.next()
|
| 342 |
+
if isinstance(responses, tuple) and len(responses) == 2:
|
| 343 |
+
prompt_responses, generation_responses = responses
|
| 344 |
+
elif isinstance(responses, list):
|
| 345 |
+
prompt_responses, generation_responses = [], responses
|
| 346 |
+
else:
|
| 347 |
+
raise RuntimeError(
|
| 348 |
+
"Unexpected mlx_lm BatchGenerator.next() API. Update mlx_lm version."
|
| 349 |
+
)
|
| 350 |
+
if not prompt_responses and not generation_responses:
|
| 351 |
+
break
|
| 352 |
+
remove_uids: list[int] = []
|
| 353 |
+
for response in generation_responses:
|
| 354 |
+
uid = response.uid
|
| 355 |
+
if uid in finished:
|
| 356 |
+
continue
|
| 357 |
+
idx = uid_to_idx[uid]
|
| 358 |
+
horizon = series[idx].horizon
|
| 359 |
+
if response.finish_reason != "stop":
|
| 360 |
+
token = int(response.token)
|
| 361 |
+
generated_tokens[uid].append(token)
|
| 362 |
+
if token == sep_token:
|
| 363 |
+
separator_counts[uid] += 1
|
| 364 |
+
if separator_counts[uid] >= horizon:
|
| 365 |
+
finalize_uid(uid, "separator_count")
|
| 366 |
+
remove_uids.append(uid)
|
| 367 |
+
continue
|
| 368 |
+
if response.finish_reason is not None:
|
| 369 |
+
finalize_uid(uid, str(response.finish_reason))
|
| 370 |
+
if remove_uids:
|
| 371 |
+
gen.remove(remove_uids)
|
| 372 |
+
if len(finished) != len(todo_indices):
|
| 373 |
+
raise RuntimeError(f"Finished {len(finished)}/{len(todo_indices)} series")
|
| 374 |
+
finally:
|
| 375 |
+
gen.close()
|
| 376 |
+
mx.clear_cache()
|
| 377 |
+
gc.collect()
|
| 378 |
+
|
| 379 |
+
eval_index_set = set(eval_indices)
|
| 380 |
+
all_records = [
|
| 381 |
+
r for idx, r in sorted(load_completed(args.output).items()) if idx in eval_index_set
|
| 382 |
+
]
|
| 383 |
+
aggregate = aggregate_summary(all_records)
|
| 384 |
+
elapsed = time.perf_counter() - started
|
| 385 |
+
summary = {
|
| 386 |
+
"data_path": str(args.data_path),
|
| 387 |
+
"model": str(args.model),
|
| 388 |
+
"output": str(args.output),
|
| 389 |
+
"series_loaded": len(series),
|
| 390 |
+
"series_evaluated": len(all_records),
|
| 391 |
+
"skipped_long": skipped_long,
|
| 392 |
+
"max_context_tokens": args.max_context_tokens,
|
| 393 |
+
"max_new_tokens": args.max_new_tokens,
|
| 394 |
+
"batch_size": args.batch_size,
|
| 395 |
+
"missing_policy": args.missing_policy,
|
| 396 |
+
"suppress_eos_until_horizon": not args.allow_eos_before_horizon,
|
| 397 |
+
"seconds": elapsed,
|
| 398 |
+
**aggregate,
|
| 399 |
+
}
|
| 400 |
+
args.summary_output.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8")
|
| 401 |
+
print(
|
| 402 |
+
json.dumps(
|
| 403 |
+
{
|
| 404 |
+
k: summary[k]
|
| 405 |
+
for k in [
|
| 406 |
+
"series_evaluated",
|
| 407 |
+
"model_macro_mape",
|
| 408 |
+
"model_point_mape",
|
| 409 |
+
"naive2_macro_mape",
|
| 410 |
+
"naive2_point_mape",
|
| 411 |
+
"parsed_full_series",
|
| 412 |
+
"total_missing_terms",
|
| 413 |
+
"malformed_series",
|
| 414 |
+
"seconds",
|
| 415 |
+
]
|
| 416 |
+
},
|
| 417 |
+
indent=2,
|
| 418 |
+
)
|
| 419 |
+
)
|
| 420 |
+
print(f"Wrote {args.output}")
|
| 421 |
+
print(f"Wrote {args.summary_output}")
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
if __name__ == "__main__":
|
| 425 |
+
main()
|
eval_results.txt
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
OEIS-Eval-Neo:
|
| 2 |
+
NextTerm-440M - 34.43%
|
| 3 |
+
NextTerm-47M - 29.49%
|
| 4 |
+
Qwen3-0.6B - 18.44%
|
| 5 |
+
Qwen3-1.7B - 20.77%
|
| 6 |
+
Qwen3-4B - 23.74%
|
| 7 |
+
Qwen3-8B - 24.62%
|
| 8 |
+
Qwen3-14B - 26.00%
|
| 9 |
+
|
| 10 |
+
Ryskina & Knight (2021):
|
| 11 |
+
NextTerm-440M - 52.63%
|
| 12 |
+
NextTerm-47M - 70.18%
|
| 13 |
+
|
| 14 |
+
16-Shot Bitstring:
|
| 15 |
+
NextTerm-440M - 32.00%
|
| 16 |
+
NextTerm-47M - 27.88%
|
| 17 |
+
|
| 18 |
+
M1 Competition 111 MAPE (macro; canonical greedy; lower is better):
|
| 19 |
+
Naive2 - 17.7987
|
| 20 |
+
NextTerm-440M - 17.6239
|
| 21 |
+
NextTerm-47M - 18.7621
|
| 22 |
+
Qwen3-0.6B - 22.7984
|
| 23 |
+
Qwen3-1.7B - 22.2411
|
| 24 |
+
Qwen3-4B - 19.1731
|
| 25 |
+
Qwen3-8B - 18.4027
|
| 26 |
+
Qwen3-14B - 17.9837
|
| 27 |
+
|
| 28 |
+
M1 Competition 111 by frequency (macro MAPE):
|
| 29 |
+
Naive2 - monthly 17.3871 / quarterly 19.5958 / yearly 17.1314
|
| 30 |
+
NextTerm-440M - monthly 18.3407 / quarterly 19.0475 / yearly 13.5498
|
| 31 |
+
NextTerm-47M - monthly 21.2719 / quarterly 16.0270 / yearly 13.3741
|
| 32 |
+
Qwen3-0.6B - monthly 24.7585 / quarterly 21.7989 / yearly 17.2835
|
| 33 |
+
Qwen3-1.7B - monthly 24.3821 / quarterly 22.5869 / yearly 14.5642
|
| 34 |
+
Qwen3-4B - monthly 20.6455 / quarterly 19.6394 / yearly 13.6308
|
| 35 |
+
Qwen3-8B - monthly 20.0289 / quarterly 17.7249 / yearly 13.6534
|
| 36 |
+
Qwen3-14B - monthly 19.4006 / quarterly 18.1729 / yearly 12.9486
|
| 37 |
+
|
| 38 |
+
Polynomial continuation evals (accuracy; 200 samples per k):
|
| 39 |
+
Arithmetic (k=2-25): NextTerm-440M - 94.38%; NextTerm-47M - 94.15%
|
| 40 |
+
Quadratic (k=3-25): NextTerm-440M - 86.39%; NextTerm-47M - 81.07%
|
| 41 |
+
Cubic (k=4-25): NextTerm-440M - 75.20%; NextTerm-47M - 37.43%
|
| 42 |
+
Quartic (k=5-25): NextTerm-440M - 67.83%; NextTerm-47M - 15.17%
|
oeis_eval_mlx_neo.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluate NextTerm on OEIS Eval Neo with MLX-LM BatchGenerator."""
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import gc
|
| 5 |
+
import inspect
|
| 6 |
+
import json
|
| 7 |
+
import time
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import mlx.core as mx
|
| 11 |
+
from mlx_lm import load
|
| 12 |
+
from mlx_lm.generate import BatchGenerator
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
DATA_PATH = Path("/Users/natebreslow/Documents/khashabiLab/bigOEIS/oeis_val_neo.jsonl")
|
| 16 |
+
MODEL_NAME = "/Users/natebreslow/Documents/khashabiLab/bigOEIS/NextTerm-440M"
|
| 17 |
+
MAX_NEW_TOKENS = 196
|
| 18 |
+
MAX_CONTEXT_TOKENS = 4096
|
| 19 |
+
BATCH_SIZE = 64
|
| 20 |
+
OUTPUT_PATH = Path("oeis_eval_results/oeis_eval_mlx_neo_per_doc.jsonl")
|
| 21 |
+
SUMMARY_PATH = Path("oeis_eval_results/oeis_eval_mlx_neo_summary.json")
|
| 22 |
+
PARSE_ERROR_PRINT_LIMIT = 25
|
| 23 |
+
parse_error_print_count = 0
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def parse_generated(text: str) -> int | None:
|
| 27 |
+
global parse_error_print_count
|
| 28 |
+
if "," in text:
|
| 29 |
+
text = text.split(",")[0]
|
| 30 |
+
try:
|
| 31 |
+
return int(text)
|
| 32 |
+
except ValueError:
|
| 33 |
+
if parse_error_print_count < PARSE_ERROR_PRINT_LIMIT:
|
| 34 |
+
print(f"Could not parse generated text: {text!r}")
|
| 35 |
+
parse_error_print_count += 1
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_sequences(path: Path):
|
| 40 |
+
sequences = []
|
| 41 |
+
answers = []
|
| 42 |
+
with path.open() as f:
|
| 43 |
+
for line in f:
|
| 44 |
+
record = json.loads(line, parse_int=str)
|
| 45 |
+
seq = record.get("seq", [])
|
| 46 |
+
if len(seq) < 2:
|
| 47 |
+
continue
|
| 48 |
+
sequences.append(seq[:-1])
|
| 49 |
+
answers.append(str(seq[-1]))
|
| 50 |
+
return sequences, answers
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def parse_args():
|
| 54 |
+
parser = argparse.ArgumentParser()
|
| 55 |
+
parser.add_argument("--data-path", type=Path, default=DATA_PATH)
|
| 56 |
+
parser.add_argument("--model", default=MODEL_NAME)
|
| 57 |
+
parser.add_argument("--max-new-tokens", type=int, default=MAX_NEW_TOKENS)
|
| 58 |
+
parser.add_argument("--max-context-tokens", type=int, default=MAX_CONTEXT_TOKENS)
|
| 59 |
+
parser.add_argument("--batch-size", type=int, default=BATCH_SIZE)
|
| 60 |
+
parser.add_argument("--max-examples", type=int, default=0)
|
| 61 |
+
parser.add_argument("--output", type=Path, default=OUTPUT_PATH)
|
| 62 |
+
parser.add_argument("--summary-output", type=Path, default=SUMMARY_PATH)
|
| 63 |
+
parser.add_argument("--overwrite", action="store_true")
|
| 64 |
+
parser.add_argument("--restrict-digit-comma-eos", action="store_true")
|
| 65 |
+
parser.add_argument("--restrict-integer-comma-eos", action="store_true")
|
| 66 |
+
return parser.parse_args()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_completed(path: Path) -> dict[int, dict]:
|
| 70 |
+
completed = {}
|
| 71 |
+
if not path.exists():
|
| 72 |
+
return completed
|
| 73 |
+
with path.open("r", encoding="utf-8") as f:
|
| 74 |
+
for line in f:
|
| 75 |
+
if not line.strip():
|
| 76 |
+
continue
|
| 77 |
+
record = json.loads(line)
|
| 78 |
+
completed[int(record["row_index"])] = record
|
| 79 |
+
return completed
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def encode_no_special(tokenizer, text: str) -> list[int]:
|
| 83 |
+
try:
|
| 84 |
+
return tokenizer.encode(text, add_special_tokens=False)
|
| 85 |
+
except TypeError:
|
| 86 |
+
return tokenizer.encode(text)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def normalize_stop_tokens_for_batch_generator(stop_tokens: list[list[int]]):
|
| 90 |
+
annotation = inspect.signature(BatchGenerator.__init__).parameters[
|
| 91 |
+
"stop_tokens"
|
| 92 |
+
].annotation
|
| 93 |
+
if "set" in str(annotation):
|
| 94 |
+
return {seq[0] for seq in stop_tokens if len(seq) == 1}
|
| 95 |
+
return stop_tokens
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def split_batch_generator_responses(responses):
|
| 99 |
+
if isinstance(responses, tuple) and len(responses) == 2:
|
| 100 |
+
prompt_responses, generation_responses = responses
|
| 101 |
+
return prompt_responses, generation_responses
|
| 102 |
+
if isinstance(responses, list):
|
| 103 |
+
return [], responses
|
| 104 |
+
raise RuntimeError(
|
| 105 |
+
"Unexpected mlx_lm BatchGenerator.next() API. Update this script for "
|
| 106 |
+
f"{type(responses).__name__}: {responses!r}"
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def make_integer_comma_eos_processor(tokenizer):
|
| 111 |
+
allowed = set()
|
| 112 |
+
for text in [str(i) for i in range(10)] + ["-", ","]:
|
| 113 |
+
tokens = encode_no_special(tokenizer, text)
|
| 114 |
+
if len(tokens) == 1:
|
| 115 |
+
allowed.add(int(tokens[0]))
|
| 116 |
+
else:
|
| 117 |
+
print(f"Skipping multi-token allowed text {text!r}: {tokens}")
|
| 118 |
+
if tokenizer.eos_token_id is not None:
|
| 119 |
+
allowed.add(int(tokenizer.eos_token_id))
|
| 120 |
+
allowed_ids = sorted(allowed)
|
| 121 |
+
mask_cache = {}
|
| 122 |
+
|
| 123 |
+
def processor(_tokens, logits):
|
| 124 |
+
vocab_size = logits.shape[-1]
|
| 125 |
+
mask = mask_cache.get(vocab_size)
|
| 126 |
+
if mask is None:
|
| 127 |
+
values = [-1e9] * vocab_size
|
| 128 |
+
for token_id in allowed_ids:
|
| 129 |
+
if 0 <= token_id < vocab_size:
|
| 130 |
+
values[token_id] = 0.0
|
| 131 |
+
mask = mx.array(values, dtype=logits.dtype)
|
| 132 |
+
mask_cache[vocab_size] = mask
|
| 133 |
+
return logits + mask[None, :]
|
| 134 |
+
|
| 135 |
+
return processor, allowed_ids
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def run_generation_queue(
|
| 139 |
+
*,
|
| 140 |
+
model,
|
| 141 |
+
tokenizer,
|
| 142 |
+
prompts,
|
| 143 |
+
answers,
|
| 144 |
+
row_indices,
|
| 145 |
+
stop_tokens,
|
| 146 |
+
max_new_tokens: int,
|
| 147 |
+
batch_size: int,
|
| 148 |
+
output_file,
|
| 149 |
+
progress,
|
| 150 |
+
logits_processors=None,
|
| 151 |
+
) -> None:
|
| 152 |
+
gen = BatchGenerator(
|
| 153 |
+
model,
|
| 154 |
+
stop_tokens=normalize_stop_tokens_for_batch_generator(stop_tokens),
|
| 155 |
+
logits_processors=logits_processors,
|
| 156 |
+
completion_batch_size=batch_size,
|
| 157 |
+
prefill_batch_size=batch_size,
|
| 158 |
+
)
|
| 159 |
+
uids = gen.insert(prompts, [max_new_tokens] * len(prompts))
|
| 160 |
+
uid_to_pos = {uid: pos for pos, uid in enumerate(uids)}
|
| 161 |
+
generated_tokens = {uid: [] for uid in uids}
|
| 162 |
+
finished = set()
|
| 163 |
+
try:
|
| 164 |
+
while True:
|
| 165 |
+
responses = gen.next()
|
| 166 |
+
prompt_responses, generation_responses = split_batch_generator_responses(
|
| 167 |
+
responses
|
| 168 |
+
)
|
| 169 |
+
if not prompt_responses and not generation_responses:
|
| 170 |
+
break
|
| 171 |
+
if not generation_responses:
|
| 172 |
+
continue
|
| 173 |
+
for response in generation_responses:
|
| 174 |
+
uid = response.uid
|
| 175 |
+
if response.finish_reason != "stop":
|
| 176 |
+
generated_tokens[uid].append(int(response.token))
|
| 177 |
+
if response.finish_reason is None or uid in finished:
|
| 178 |
+
continue
|
| 179 |
+
finished.add(uid)
|
| 180 |
+
pos = uid_to_pos[uid]
|
| 181 |
+
text = tokenizer.decode(generated_tokens[uid])
|
| 182 |
+
prediction = parse_generated(text)
|
| 183 |
+
answer = answers[pos]
|
| 184 |
+
answer_int = int(answer)
|
| 185 |
+
record = {
|
| 186 |
+
"row_index": row_indices[pos],
|
| 187 |
+
"answer": answer,
|
| 188 |
+
"prediction": prediction,
|
| 189 |
+
"correct": prediction == answer_int,
|
| 190 |
+
"parsed": prediction is not None,
|
| 191 |
+
"generated_text": text,
|
| 192 |
+
"generated_tokens": generated_tokens[uid],
|
| 193 |
+
"finish_reason": response.finish_reason,
|
| 194 |
+
}
|
| 195 |
+
output_file.write(json.dumps(record) + "\n")
|
| 196 |
+
progress.update(1)
|
| 197 |
+
finally:
|
| 198 |
+
gen.close()
|
| 199 |
+
mx.clear_cache()
|
| 200 |
+
gc.collect()
|
| 201 |
+
|
| 202 |
+
if len(finished) != len(uids):
|
| 203 |
+
raise RuntimeError(f"Chunk finished {len(finished)}/{len(uids)} rows")
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def main():
|
| 207 |
+
args = parse_args()
|
| 208 |
+
started = time.perf_counter()
|
| 209 |
+
sequences, answers = load_sequences(args.data_path)
|
| 210 |
+
if args.max_examples > 0:
|
| 211 |
+
sequences = sequences[: args.max_examples]
|
| 212 |
+
answers = answers[: args.max_examples]
|
| 213 |
+
print(f"Loaded {len(answers)} sequences from {args.data_path}")
|
| 214 |
+
|
| 215 |
+
model, tokenizer = load(args.model)
|
| 216 |
+
|
| 217 |
+
sep_tokens = encode_no_special(tokenizer, ",")
|
| 218 |
+
if not sep_tokens:
|
| 219 |
+
sep_tokens = encode_no_special(tokenizer, "1,")[-1:]
|
| 220 |
+
prompts = [",".join(str(x) for x in seq) + "," for seq in sequences]
|
| 221 |
+
prompts = [tokenizer.encode(p) for p in prompts]
|
| 222 |
+
|
| 223 |
+
eval_indices = [
|
| 224 |
+
i
|
| 225 |
+
for i, prompt in enumerate(prompts)
|
| 226 |
+
if args.max_context_tokens <= 0 or len(prompt) < args.max_context_tokens
|
| 227 |
+
]
|
| 228 |
+
skipped_long = len(prompts) - len(eval_indices)
|
| 229 |
+
eval_indices = sorted(eval_indices, key=lambda i: len(prompts[i]))
|
| 230 |
+
print(
|
| 231 |
+
f"Evaluating {len(eval_indices)} rows; skipped_long={skipped_long}; "
|
| 232 |
+
f"sep_tokens={sep_tokens}; eos_token={tokenizer.eos_token_id}; "
|
| 233 |
+
f"max_new_tokens={args.max_new_tokens}; batch_size={args.batch_size}"
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
stop_tokens = [sep_tokens]
|
| 237 |
+
if tokenizer.eos_token_id is not None:
|
| 238 |
+
stop_tokens.append([tokenizer.eos_token_id])
|
| 239 |
+
logits_processors = None
|
| 240 |
+
allowed_token_ids = None
|
| 241 |
+
restrict_integer = args.restrict_digit_comma_eos or args.restrict_integer_comma_eos
|
| 242 |
+
if restrict_integer:
|
| 243 |
+
processor, allowed_token_ids = make_integer_comma_eos_processor(tokenizer)
|
| 244 |
+
logits_processors = [processor]
|
| 245 |
+
print(f"Restricting logits to integer/comma/EOS token ids: {allowed_token_ids}")
|
| 246 |
+
args.output.parent.mkdir(parents=True, exist_ok=True)
|
| 247 |
+
args.summary_output.parent.mkdir(parents=True, exist_ok=True)
|
| 248 |
+
if args.overwrite and args.output.exists():
|
| 249 |
+
args.output.unlink()
|
| 250 |
+
completed = load_completed(args.output)
|
| 251 |
+
if completed:
|
| 252 |
+
print(f"Resuming from {args.output}: {len(completed)} rows already done")
|
| 253 |
+
todo_indices = [idx for idx in eval_indices if idx not in completed]
|
| 254 |
+
|
| 255 |
+
with args.output.open("a", encoding="utf-8") as output_file:
|
| 256 |
+
with tqdm(total=len(todo_indices), desc="Generating") as progress:
|
| 257 |
+
run_generation_queue(
|
| 258 |
+
model=model,
|
| 259 |
+
tokenizer=tokenizer,
|
| 260 |
+
prompts=[prompts[i] for i in todo_indices],
|
| 261 |
+
answers=[answers[i] for i in todo_indices],
|
| 262 |
+
row_indices=todo_indices,
|
| 263 |
+
stop_tokens=stop_tokens,
|
| 264 |
+
max_new_tokens=args.max_new_tokens,
|
| 265 |
+
batch_size=args.batch_size,
|
| 266 |
+
output_file=output_file,
|
| 267 |
+
progress=progress,
|
| 268 |
+
logits_processors=logits_processors,
|
| 269 |
+
)
|
| 270 |
+
output_file.flush()
|
| 271 |
+
|
| 272 |
+
records = load_completed(args.output)
|
| 273 |
+
eval_set = set(eval_indices)
|
| 274 |
+
records = {idx: record for idx, record in records.items() if idx in eval_set}
|
| 275 |
+
correct = sum(1 for record in records.values() if record["correct"])
|
| 276 |
+
parsed = sum(1 for record in records.values() if record["parsed"])
|
| 277 |
+
total = len(eval_indices)
|
| 278 |
+
evaluated = len(records)
|
| 279 |
+
elapsed = time.perf_counter() - started
|
| 280 |
+
|
| 281 |
+
print(f"Documents: {len(answers)}")
|
| 282 |
+
print(f"Evaluated: {evaluated}/{total}")
|
| 283 |
+
print(f"Skipped long: {skipped_long}")
|
| 284 |
+
print(f"Parsed predictions: {parsed}/{evaluated}")
|
| 285 |
+
print(f"Accuracy: {correct}/{evaluated} = {correct / evaluated:.4f}")
|
| 286 |
+
summary = {
|
| 287 |
+
"data_path": str(args.data_path),
|
| 288 |
+
"model": args.model,
|
| 289 |
+
"output": str(args.output),
|
| 290 |
+
"documents": len(answers),
|
| 291 |
+
"evaluated": evaluated,
|
| 292 |
+
"expected_evaluated": total,
|
| 293 |
+
"skipped_long": skipped_long,
|
| 294 |
+
"parsed": parsed,
|
| 295 |
+
"correct": correct,
|
| 296 |
+
"accuracy": correct / evaluated if evaluated else 0.0,
|
| 297 |
+
"max_new_tokens": args.max_new_tokens,
|
| 298 |
+
"max_context_tokens": args.max_context_tokens,
|
| 299 |
+
"batch_size": args.batch_size,
|
| 300 |
+
"restrict_digit_comma_eos": args.restrict_digit_comma_eos,
|
| 301 |
+
"restrict_integer_comma_eos": restrict_integer,
|
| 302 |
+
"allowed_token_ids": allowed_token_ids,
|
| 303 |
+
"seconds": elapsed,
|
| 304 |
+
}
|
| 305 |
+
args.summary_output.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8")
|
| 306 |
+
print(f"Wrote {args.output}")
|
| 307 |
+
print(f"Wrote {args.summary_output}")
|
| 308 |
+
|
| 309 |
+
if __name__ == "__main__":
|
| 310 |
+
main()
|
radar_chart.png
ADDED
|
Git LFS Details
|