Text Generation
Transformers
Safetensors
MLX
qwen3
oeis
integer-sequences
causal-lm
text-generation-inference
Instructions to use N8Programs/NextTerm-440M with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use N8Programs/NextTerm-440M with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="N8Programs/NextTerm-440M")# Load model directly from transformers import AutoTokenizer, AutoModelForCausalLM tokenizer = AutoTokenizer.from_pretrained("N8Programs/NextTerm-440M") model = AutoModelForCausalLM.from_pretrained("N8Programs/NextTerm-440M") - MLX
How to use N8Programs/NextTerm-440M with MLX:
# Make sure mlx-lm is installed # pip install --upgrade mlx-lm # if on a CUDA device, also pip install mlx[cuda] # Generate text with mlx-lm from mlx_lm import load, generate model, tokenizer = load("N8Programs/NextTerm-440M") prompt = "Once upon a time in" text = generate(model, tokenizer, prompt=prompt, verbose=True) - Inference
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
- vLLM
How to use N8Programs/NextTerm-440M with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "N8Programs/NextTerm-440M" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "N8Programs/NextTerm-440M", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker
docker model run hf.co/N8Programs/NextTerm-440M
- SGLang
How to use N8Programs/NextTerm-440M with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "N8Programs/NextTerm-440M" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "N8Programs/NextTerm-440M", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "N8Programs/NextTerm-440M" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "N8Programs/NextTerm-440M", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }' - MLX LM
How to use N8Programs/NextTerm-440M with MLX LM:
Generate or start a chat session
# Install MLX LM uv tool install mlx-lm # Generate some text mlx_lm.generate --model "N8Programs/NextTerm-440M" --prompt "Once upon a time"
- Docker Model Runner
How to use N8Programs/NextTerm-440M with Docker Model Runner:
docker model run hf.co/N8Programs/NextTerm-440M
Bundle evaluation datasets with model card scripts
Browse files- README.md +3 -1
- arithmetic_eval.py +14 -2
- eval_m1_competition_mape_mlx.py +14 -2
- eval_m1_monthly_mape_mlx.py +583 -0
- m1_competition_111.jsonl +0 -0
- m1_monthly_dataset.txt +0 -0
- oeis_eval_mlx_neo.py +14 -2
- oeis_val_neo.jsonl +0 -0
- oeis_val_neo.meta.json +9 -0
- oeis_val_neo_excluded_exact_packed_matches.tsv +74 -0
README.md
CHANGED
|
@@ -142,11 +142,13 @@ text before the first comma as the predicted integer.
|
|
| 142 |
## Reproducibility
|
| 143 |
|
| 144 |
This repository contains the local evaluation scripts and artifacts used for the
|
| 145 |
-
results above, including:
|
| 146 |
|
| 147 |
- `oeis_eval_mlx_neo.py` for OEIS-Eval-Neo with MLX batch generation.
|
| 148 |
- `arithmetic_eval.py` for arithmetic/quadratic/cubic/quartic continuation.
|
| 149 |
- `eval_m1_competition_mape_mlx.py` for M1 Competition 111 MAPE.
|
|
|
|
|
|
|
| 150 |
- `eval_results.txt` for the compact result table.
|
| 151 |
|
| 152 |
|
|
|
|
| 142 |
## Reproducibility
|
| 143 |
|
| 144 |
This repository contains the local evaluation scripts and artifacts used for the
|
| 145 |
+
results above, including the small evaluation datasets needed to rerun them:
|
| 146 |
|
| 147 |
- `oeis_eval_mlx_neo.py` for OEIS-Eval-Neo with MLX batch generation.
|
| 148 |
- `arithmetic_eval.py` for arithmetic/quadratic/cubic/quartic continuation.
|
| 149 |
- `eval_m1_competition_mape_mlx.py` for M1 Competition 111 MAPE.
|
| 150 |
+
- `oeis_val_neo.jsonl` for OEIS-Eval-Neo.
|
| 151 |
+
- `m1_competition_111.jsonl` for M1 Competition 111.
|
| 152 |
- `eval_results.txt` for the compact result table.
|
| 153 |
|
| 154 |
|
arithmetic_eval.py
CHANGED
|
@@ -19,6 +19,7 @@ import argparse
|
|
| 19 |
import csv
|
| 20 |
import random
|
| 21 |
from dataclasses import dataclass
|
|
|
|
| 22 |
from typing import Dict, List, Tuple
|
| 23 |
|
| 24 |
from mlx_lm import load
|
|
@@ -26,8 +27,19 @@ from mlx_lm.generate import BatchGenerator
|
|
| 26 |
from mlx_lm.tuner.utils import print_trainable_parameters
|
| 27 |
from tqdm import tqdm
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
MAX_NEW_TOKENS = 20
|
| 32 |
DEFAULT_MAX_TERMS = 25
|
| 33 |
SAMPLES_PER_K = 200
|
|
|
|
| 19 |
import csv
|
| 20 |
import random
|
| 21 |
from dataclasses import dataclass
|
| 22 |
+
from pathlib import Path
|
| 23 |
from typing import Dict, List, Tuple
|
| 24 |
|
| 25 |
from mlx_lm import load
|
|
|
|
| 27 |
from mlx_lm.tuner.utils import print_trainable_parameters
|
| 28 |
from tqdm import tqdm
|
| 29 |
|
| 30 |
+
SCRIPT_DIR = Path(__file__).resolve().parent
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def default_model_path() -> str:
|
| 34 |
+
if (SCRIPT_DIR / "model.safetensors").exists():
|
| 35 |
+
return str(SCRIPT_DIR)
|
| 36 |
+
local_model = SCRIPT_DIR / "NextTerm-47M"
|
| 37 |
+
if local_model.exists():
|
| 38 |
+
return str(local_model)
|
| 39 |
+
return "N8Programs/NextTerm-47M"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
MODEL_NAME = default_model_path()
|
| 43 |
MAX_NEW_TOKENS = 20
|
| 44 |
DEFAULT_MAX_TERMS = 25
|
| 45 |
SAMPLES_PER_K = 200
|
eval_m1_competition_mape_mlx.py
CHANGED
|
@@ -28,8 +28,20 @@ from eval_m1_monthly_mape_mlx import (
|
|
| 28 |
)
|
| 29 |
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
OUTPUT_PATH = Path("m1_eval_results/m1_competition111_nextterm440m_greedy_per_series.jsonl")
|
| 34 |
SUMMARY_PATH = Path("m1_eval_results/m1_competition111_nextterm440m_greedy_summary.json")
|
| 35 |
|
|
|
|
| 28 |
)
|
| 29 |
|
| 30 |
|
| 31 |
+
SCRIPT_DIR = Path(__file__).resolve().parent
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def default_model_path() -> Path:
|
| 35 |
+
if (SCRIPT_DIR / "model.safetensors").exists():
|
| 36 |
+
return SCRIPT_DIR
|
| 37 |
+
local_model = SCRIPT_DIR / "NextTerm-440M"
|
| 38 |
+
if local_model.exists():
|
| 39 |
+
return local_model
|
| 40 |
+
return Path("N8Programs/NextTerm-440M")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
DATA_PATH = SCRIPT_DIR / "m1_competition_111.jsonl"
|
| 44 |
+
MODEL_PATH = default_model_path()
|
| 45 |
OUTPUT_PATH = Path("m1_eval_results/m1_competition111_nextterm440m_greedy_per_series.jsonl")
|
| 46 |
SUMMARY_PATH = Path("m1_eval_results/m1_competition111_nextterm440m_greedy_summary.json")
|
| 47 |
|
eval_m1_monthly_mape_mlx.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Evaluate NextTerm-style MLX models on the M1 monthly forecasting dataset."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import gc
|
| 8 |
+
import json
|
| 9 |
+
import math
|
| 10 |
+
import re
|
| 11 |
+
import time
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from decimal import Decimal, InvalidOperation, localcontext
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from statistics import mean, median
|
| 16 |
+
|
| 17 |
+
import mlx.core as mx
|
| 18 |
+
from mlx_lm import load
|
| 19 |
+
from mlx_lm.generate import BatchGenerator
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
SCRIPT_DIR = Path(__file__).resolve().parent
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def default_model_path() -> Path:
|
| 27 |
+
if (SCRIPT_DIR / "model.safetensors").exists():
|
| 28 |
+
return SCRIPT_DIR
|
| 29 |
+
local_model = SCRIPT_DIR / "NextTerm-47M"
|
| 30 |
+
if local_model.exists():
|
| 31 |
+
return local_model
|
| 32 |
+
return Path("N8Programs/NextTerm-47M")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
DATA_PATH = SCRIPT_DIR / "m1_monthly_dataset.txt"
|
| 36 |
+
MODEL_PATH = default_model_path()
|
| 37 |
+
OUTPUT_PATH = Path("m1_eval_results/m1_monthly_nextterm47m_per_series.jsonl")
|
| 38 |
+
SUMMARY_PATH = Path("m1_eval_results/m1_monthly_nextterm47m_summary.json")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
|
| 42 |
+
class SeriesRecord:
|
| 43 |
+
row_index: int
|
| 44 |
+
series_name: str
|
| 45 |
+
start_timestamp: str
|
| 46 |
+
raw_values: list[str]
|
| 47 |
+
values: list[Decimal]
|
| 48 |
+
context_values: list[Decimal]
|
| 49 |
+
target_values: list[Decimal]
|
| 50 |
+
scale: int
|
| 51 |
+
scaled_context: list[int]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SuppressTokenLogits:
|
| 55 |
+
"""Set selected token logits to a large negative value."""
|
| 56 |
+
|
| 57 |
+
def __init__(self, token_ids: list[int]):
|
| 58 |
+
self.token_ids = sorted({int(t) for t in token_ids if t is not None and int(t) >= 0})
|
| 59 |
+
self._bias_by_width: dict[int, mx.array] = {}
|
| 60 |
+
|
| 61 |
+
def __call__(self, tokens: mx.array, logits: mx.array) -> mx.array:
|
| 62 |
+
width = int(logits.shape[-1])
|
| 63 |
+
bias = self._bias_by_width.get(width)
|
| 64 |
+
if bias is None:
|
| 65 |
+
values = [0.0] * width
|
| 66 |
+
for token_id in self.token_ids:
|
| 67 |
+
if token_id < width:
|
| 68 |
+
values[token_id] = -1.0e9
|
| 69 |
+
bias = mx.array(values, dtype=logits.dtype)
|
| 70 |
+
self._bias_by_width[width] = bias
|
| 71 |
+
return logits + bias
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def parse_decimal(raw: str) -> Decimal:
|
| 75 |
+
raw = raw.strip()
|
| 76 |
+
try:
|
| 77 |
+
value = Decimal(raw)
|
| 78 |
+
except InvalidOperation as exc:
|
| 79 |
+
raise ValueError(f"Could not parse decimal value {raw!r}") from exc
|
| 80 |
+
if not value.is_finite():
|
| 81 |
+
raise ValueError(f"Non-finite decimal value {raw!r}")
|
| 82 |
+
return value
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def context_scale(raw_context_values: list[str]) -> int:
|
| 86 |
+
"""Choose an integer scale using only visible decimal precision in context."""
|
| 87 |
+
max_places = 0
|
| 88 |
+
for raw in raw_context_values:
|
| 89 |
+
value = parse_decimal(raw)
|
| 90 |
+
exponent = value.as_tuple().exponent
|
| 91 |
+
if exponent < 0:
|
| 92 |
+
max_places = max(max_places, -exponent)
|
| 93 |
+
return 10**max_places
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def scale_decimal_to_int(value: Decimal, scale: int) -> int:
|
| 97 |
+
scaled = value * Decimal(scale)
|
| 98 |
+
with localcontext() as ctx:
|
| 99 |
+
ctx.prec = max(50, len(scaled.as_tuple().digits) + 10)
|
| 100 |
+
rounded = scaled.to_integral_value()
|
| 101 |
+
if rounded != scaled:
|
| 102 |
+
# This can only happen if the held-out side has more precision than the
|
| 103 |
+
# context-derived scale. It is fine for scoring, but not for prompting.
|
| 104 |
+
raise ValueError(f"Context scale {scale} does not make {value} integral")
|
| 105 |
+
return int(rounded)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def load_m1_monthly(path: Path, horizon: int | None = None) -> tuple[list[SeriesRecord], int]:
|
| 109 |
+
records: list[SeriesRecord] = []
|
| 110 |
+
parsed_horizon: int | None = horizon
|
| 111 |
+
in_data = False
|
| 112 |
+
with path.open("r", encoding="latin-1", newline=None) as f:
|
| 113 |
+
for raw_line in f:
|
| 114 |
+
line = raw_line.strip()
|
| 115 |
+
if not line:
|
| 116 |
+
continue
|
| 117 |
+
if line.startswith("@horizon") and parsed_horizon is None:
|
| 118 |
+
parts = line.split()
|
| 119 |
+
if len(parts) >= 2:
|
| 120 |
+
parsed_horizon = int(parts[1])
|
| 121 |
+
continue
|
| 122 |
+
if line == "@data":
|
| 123 |
+
in_data = True
|
| 124 |
+
continue
|
| 125 |
+
if not in_data or line.startswith("#") or line.startswith("@"):
|
| 126 |
+
continue
|
| 127 |
+
|
| 128 |
+
try:
|
| 129 |
+
series_name, start_timestamp, values_blob = line.split(":", 2)
|
| 130 |
+
except ValueError as exc:
|
| 131 |
+
raise ValueError(f"Malformed M1 data line: {line[:120]!r}") from exc
|
| 132 |
+
raw_values = [part.strip() for part in values_blob.split(",") if part.strip()]
|
| 133 |
+
values = [parse_decimal(part) for part in raw_values]
|
| 134 |
+
if parsed_horizon is None:
|
| 135 |
+
raise ValueError("No horizon provided and no @horizon metadata found")
|
| 136 |
+
if len(values) <= parsed_horizon:
|
| 137 |
+
continue
|
| 138 |
+
|
| 139 |
+
raw_context = raw_values[:-parsed_horizon]
|
| 140 |
+
scale = context_scale(raw_context)
|
| 141 |
+
context_values = values[:-parsed_horizon]
|
| 142 |
+
target_values = values[-parsed_horizon:]
|
| 143 |
+
scaled_context = [scale_decimal_to_int(v, scale) for v in context_values]
|
| 144 |
+
records.append(
|
| 145 |
+
SeriesRecord(
|
| 146 |
+
row_index=len(records),
|
| 147 |
+
series_name=series_name,
|
| 148 |
+
start_timestamp=start_timestamp,
|
| 149 |
+
raw_values=raw_values,
|
| 150 |
+
values=values,
|
| 151 |
+
context_values=context_values,
|
| 152 |
+
target_values=target_values,
|
| 153 |
+
scale=scale,
|
| 154 |
+
scaled_context=scaled_context,
|
| 155 |
+
)
|
| 156 |
+
)
|
| 157 |
+
if parsed_horizon is None:
|
| 158 |
+
raise ValueError("No horizon provided and no @horizon metadata found")
|
| 159 |
+
return records, parsed_horizon
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def parse_generated_terms(text: str, limit: int) -> tuple[list[int], bool]:
|
| 163 |
+
terms: list[int] = []
|
| 164 |
+
current: list[str] = []
|
| 165 |
+
malformed = False
|
| 166 |
+
|
| 167 |
+
def flush_current() -> None:
|
| 168 |
+
nonlocal malformed
|
| 169 |
+
if not current:
|
| 170 |
+
return
|
| 171 |
+
token = "".join(current)
|
| 172 |
+
current.clear()
|
| 173 |
+
if token == "-":
|
| 174 |
+
malformed = True
|
| 175 |
+
return
|
| 176 |
+
try:
|
| 177 |
+
terms.append(int(token))
|
| 178 |
+
except ValueError:
|
| 179 |
+
malformed = True
|
| 180 |
+
|
| 181 |
+
for ch in text:
|
| 182 |
+
if len(terms) >= limit:
|
| 183 |
+
break
|
| 184 |
+
if ch.isdigit() or (ch == "-" and not current):
|
| 185 |
+
current.append(ch)
|
| 186 |
+
elif ch == ",":
|
| 187 |
+
flush_current()
|
| 188 |
+
elif ch.isspace():
|
| 189 |
+
continue
|
| 190 |
+
else:
|
| 191 |
+
malformed = True
|
| 192 |
+
flush_current()
|
| 193 |
+
break
|
| 194 |
+
if len(terms) < limit:
|
| 195 |
+
flush_current()
|
| 196 |
+
return terms[:limit], malformed
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def as_float(value: Decimal) -> float:
|
| 200 |
+
return float(value)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def ape(actual: Decimal, prediction: Decimal) -> float | None:
|
| 204 |
+
if actual == 0:
|
| 205 |
+
return None
|
| 206 |
+
return float(abs(actual - prediction) / abs(actual) * Decimal(100))
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def mape_for_predictions(
|
| 210 |
+
actuals: list[Decimal],
|
| 211 |
+
predictions: list[Decimal | None],
|
| 212 |
+
*,
|
| 213 |
+
missing_policy: str,
|
| 214 |
+
fallback: Decimal,
|
| 215 |
+
) -> tuple[float | None, list[float | None], int]:
|
| 216 |
+
apes: list[float | None] = []
|
| 217 |
+
missing = 0
|
| 218 |
+
for actual, prediction in zip(actuals, predictions):
|
| 219 |
+
pred = prediction
|
| 220 |
+
if pred is None:
|
| 221 |
+
missing += 1
|
| 222 |
+
if missing_policy == "skip":
|
| 223 |
+
apes.append(None)
|
| 224 |
+
continue
|
| 225 |
+
if missing_policy == "zero":
|
| 226 |
+
pred = Decimal(0)
|
| 227 |
+
elif missing_policy == "last_context":
|
| 228 |
+
pred = fallback
|
| 229 |
+
else:
|
| 230 |
+
raise ValueError(f"Unknown missing policy: {missing_policy}")
|
| 231 |
+
apes.append(ape(actual, pred))
|
| 232 |
+
valid = [x for x in apes if x is not None and math.isfinite(x)]
|
| 233 |
+
return (mean(valid) if valid else None), apes, missing
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def seasonal_naive_predictions(context: list[Decimal], horizon: int, season: int = 12) -> list[Decimal]:
|
| 237 |
+
if not context:
|
| 238 |
+
return [Decimal(0)] * horizon
|
| 239 |
+
if len(context) < season:
|
| 240 |
+
return [context[-1]] * horizon
|
| 241 |
+
last_season = context[-season:]
|
| 242 |
+
return [last_season[i % season] for i in range(horizon)]
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def last_value_predictions(context: list[Decimal], horizon: int) -> list[Decimal]:
|
| 246 |
+
fallback = context[-1] if context else Decimal(0)
|
| 247 |
+
return [fallback] * horizon
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def load_completed(path: Path) -> dict[int, dict]:
|
| 251 |
+
completed: dict[int, dict] = {}
|
| 252 |
+
if not path.exists():
|
| 253 |
+
return completed
|
| 254 |
+
with path.open("r", encoding="utf-8") as f:
|
| 255 |
+
for line in f:
|
| 256 |
+
if not line.strip():
|
| 257 |
+
continue
|
| 258 |
+
record = json.loads(line)
|
| 259 |
+
completed[int(record["row_index"])] = record
|
| 260 |
+
return completed
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def aggregate_summary(records: list[dict], horizon: int) -> dict:
|
| 264 |
+
def collect(key: str) -> list[float]:
|
| 265 |
+
return [
|
| 266 |
+
float(r[key])
|
| 267 |
+
for r in records
|
| 268 |
+
if r.get(key) is not None and math.isfinite(float(r[key]))
|
| 269 |
+
]
|
| 270 |
+
|
| 271 |
+
model_series = collect("mape")
|
| 272 |
+
seasonal_series = collect("seasonal_naive_mape")
|
| 273 |
+
last_series = collect("last_value_mape")
|
| 274 |
+
per_horizon: list[dict] = []
|
| 275 |
+
for h in range(horizon):
|
| 276 |
+
vals = []
|
| 277 |
+
seasonal_vals = []
|
| 278 |
+
last_vals = []
|
| 279 |
+
for r in records:
|
| 280 |
+
for source, target in [
|
| 281 |
+
("apes", vals),
|
| 282 |
+
("seasonal_naive_apes", seasonal_vals),
|
| 283 |
+
("last_value_apes", last_vals),
|
| 284 |
+
]:
|
| 285 |
+
xs = r.get(source) or []
|
| 286 |
+
if h < len(xs) and xs[h] is not None and math.isfinite(float(xs[h])):
|
| 287 |
+
target.append(float(xs[h]))
|
| 288 |
+
per_horizon.append(
|
| 289 |
+
{
|
| 290 |
+
"horizon": h + 1,
|
| 291 |
+
"mape": mean(vals) if vals else None,
|
| 292 |
+
"seasonal_naive_mape": mean(seasonal_vals) if seasonal_vals else None,
|
| 293 |
+
"last_value_mape": mean(last_vals) if last_vals else None,
|
| 294 |
+
"n": len(vals),
|
| 295 |
+
}
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
all_apes = [
|
| 299 |
+
float(x)
|
| 300 |
+
for r in records
|
| 301 |
+
for x in (r.get("apes") or [])
|
| 302 |
+
if x is not None and math.isfinite(float(x))
|
| 303 |
+
]
|
| 304 |
+
all_seasonal_apes = [
|
| 305 |
+
float(x)
|
| 306 |
+
for r in records
|
| 307 |
+
for x in (r.get("seasonal_naive_apes") or [])
|
| 308 |
+
if x is not None and math.isfinite(float(x))
|
| 309 |
+
]
|
| 310 |
+
all_last_apes = [
|
| 311 |
+
float(x)
|
| 312 |
+
for r in records
|
| 313 |
+
for x in (r.get("last_value_apes") or [])
|
| 314 |
+
if x is not None and math.isfinite(float(x))
|
| 315 |
+
]
|
| 316 |
+
return {
|
| 317 |
+
"series_count": len(records),
|
| 318 |
+
"model_macro_mape": mean(model_series) if model_series else None,
|
| 319 |
+
"model_median_series_mape": median(model_series) if model_series else None,
|
| 320 |
+
"model_point_mape": mean(all_apes) if all_apes else None,
|
| 321 |
+
"seasonal_naive_macro_mape": mean(seasonal_series) if seasonal_series else None,
|
| 322 |
+
"seasonal_naive_point_mape": mean(all_seasonal_apes) if all_seasonal_apes else None,
|
| 323 |
+
"last_value_macro_mape": mean(last_series) if last_series else None,
|
| 324 |
+
"last_value_point_mape": mean(all_last_apes) if all_last_apes else None,
|
| 325 |
+
"parsed_full_series": sum(1 for r in records if r.get("parsed_terms", 0) >= horizon),
|
| 326 |
+
"total_missing_terms": sum(int(r.get("missing_terms", 0)) for r in records),
|
| 327 |
+
"malformed_series": sum(1 for r in records if r.get("malformed", False)),
|
| 328 |
+
"per_horizon": per_horizon,
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def parse_args() -> argparse.Namespace:
|
| 333 |
+
parser = argparse.ArgumentParser()
|
| 334 |
+
parser.add_argument("--data-path", type=Path, default=DATA_PATH)
|
| 335 |
+
parser.add_argument("--model", type=Path, default=MODEL_PATH)
|
| 336 |
+
parser.add_argument("--output", type=Path, default=OUTPUT_PATH)
|
| 337 |
+
parser.add_argument("--summary-output", type=Path, default=SUMMARY_PATH)
|
| 338 |
+
parser.add_argument("--horizon", type=int, default=None)
|
| 339 |
+
parser.add_argument("--batch-size", type=int, default=64)
|
| 340 |
+
parser.add_argument("--max-new-tokens", type=int, default=384)
|
| 341 |
+
parser.add_argument("--max-context-tokens", type=int, default=2048)
|
| 342 |
+
parser.add_argument("--max-series", type=int, default=0)
|
| 343 |
+
parser.add_argument(
|
| 344 |
+
"--suppress-eos-until-horizon",
|
| 345 |
+
action="store_true",
|
| 346 |
+
help="Suppress EOS/pad logits and stop only after the requested separator count or length cap.",
|
| 347 |
+
)
|
| 348 |
+
parser.add_argument(
|
| 349 |
+
"--rerun-incomplete",
|
| 350 |
+
action="store_true",
|
| 351 |
+
help="When resuming, rerun rows whose previous record parsed fewer than horizon terms.",
|
| 352 |
+
)
|
| 353 |
+
parser.add_argument(
|
| 354 |
+
"--missing-policy",
|
| 355 |
+
choices=["zero", "last_context", "skip"],
|
| 356 |
+
default="zero",
|
| 357 |
+
help="How to score missing/unparseable forecast terms.",
|
| 358 |
+
)
|
| 359 |
+
parser.add_argument("--overwrite", action="store_true")
|
| 360 |
+
return parser.parse_args()
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def main() -> None:
|
| 364 |
+
args = parse_args()
|
| 365 |
+
started = time.perf_counter()
|
| 366 |
+
series, horizon = load_m1_monthly(args.data_path, args.horizon)
|
| 367 |
+
if args.max_series > 0:
|
| 368 |
+
series = series[: args.max_series]
|
| 369 |
+
print(f"Loaded {len(series)} M1 monthly series from {args.data_path}; horizon={horizon}")
|
| 370 |
+
|
| 371 |
+
model, tokenizer = load(str(args.model))
|
| 372 |
+
print(f"Loaded model: {args.model}")
|
| 373 |
+
|
| 374 |
+
prompts_text = [",".join(str(x) for x in s.scaled_context) + "," for s in series]
|
| 375 |
+
prompts = [tokenizer.encode(text) for text in prompts_text]
|
| 376 |
+
sep_token = tokenizer.encode("1,")[-1]
|
| 377 |
+
eval_indices = [
|
| 378 |
+
i
|
| 379 |
+
for i, prompt in enumerate(prompts)
|
| 380 |
+
if args.max_context_tokens <= 0 or len(prompt) < args.max_context_tokens
|
| 381 |
+
]
|
| 382 |
+
skipped_long = len(prompts) - len(eval_indices)
|
| 383 |
+
eval_indices = sorted(eval_indices, key=lambda i: len(prompts[i]))
|
| 384 |
+
print(
|
| 385 |
+
f"Evaluating {len(eval_indices)} series; skipped_long={skipped_long}; "
|
| 386 |
+
f"batch_size={args.batch_size}; max_new_tokens={args.max_new_tokens}; sep_token={sep_token}"
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
args.output.parent.mkdir(parents=True, exist_ok=True)
|
| 390 |
+
args.summary_output.parent.mkdir(parents=True, exist_ok=True)
|
| 391 |
+
if args.overwrite and args.output.exists():
|
| 392 |
+
args.output.unlink()
|
| 393 |
+
|
| 394 |
+
completed = load_completed(args.output)
|
| 395 |
+
if args.rerun_incomplete:
|
| 396 |
+
incomplete = [
|
| 397 |
+
idx
|
| 398 |
+
for idx, record in completed.items()
|
| 399 |
+
if int(record.get("parsed_terms", 0)) < horizon
|
| 400 |
+
]
|
| 401 |
+
for idx in incomplete:
|
| 402 |
+
completed.pop(idx, None)
|
| 403 |
+
if incomplete:
|
| 404 |
+
print(
|
| 405 |
+
"Rerunning incomplete rows: "
|
| 406 |
+
+ ", ".join(str(idx) for idx in sorted(incomplete))
|
| 407 |
+
)
|
| 408 |
+
if completed:
|
| 409 |
+
print(f"Resuming from {args.output}: {len(completed)} series already done")
|
| 410 |
+
todo_indices = [i for i in eval_indices if i not in completed]
|
| 411 |
+
|
| 412 |
+
stop_tokens = [] if args.suppress_eos_until_horizon else [[tokenizer.eos_token_id]]
|
| 413 |
+
suppress_token_ids: list[int] = []
|
| 414 |
+
if getattr(tokenizer, "eos_token_id", None) is not None:
|
| 415 |
+
suppress_token_ids.append(tokenizer.eos_token_id)
|
| 416 |
+
if getattr(tokenizer, "pad_token_id", None) is not None:
|
| 417 |
+
if not args.suppress_eos_until_horizon:
|
| 418 |
+
stop_tokens.append([tokenizer.pad_token_id])
|
| 419 |
+
suppress_token_ids.append(tokenizer.pad_token_id)
|
| 420 |
+
suppress_processor = (
|
| 421 |
+
SuppressTokenLogits(suppress_token_ids) if args.suppress_eos_until_horizon else None
|
| 422 |
+
)
|
| 423 |
+
gen = BatchGenerator(
|
| 424 |
+
model,
|
| 425 |
+
stop_tokens=stop_tokens or None,
|
| 426 |
+
completion_batch_size=args.batch_size,
|
| 427 |
+
prefill_batch_size=args.batch_size,
|
| 428 |
+
)
|
| 429 |
+
uid_to_idx: dict[int, int] = {}
|
| 430 |
+
generated_tokens: dict[int, list[int]] = {}
|
| 431 |
+
separator_counts: dict[int, int] = {}
|
| 432 |
+
if todo_indices:
|
| 433 |
+
logits_processors = (
|
| 434 |
+
[[suppress_processor] for _ in todo_indices]
|
| 435 |
+
if suppress_processor is not None
|
| 436 |
+
else None
|
| 437 |
+
)
|
| 438 |
+
uids = gen.insert(
|
| 439 |
+
[prompts[i] for i in todo_indices],
|
| 440 |
+
[args.max_new_tokens] * len(todo_indices),
|
| 441 |
+
logits_processors=logits_processors,
|
| 442 |
+
)
|
| 443 |
+
uid_to_idx = {uid: idx for uid, idx in zip(uids, todo_indices)}
|
| 444 |
+
generated_tokens = {uid: [] for uid in uids}
|
| 445 |
+
separator_counts = {uid: 0 for uid in uids}
|
| 446 |
+
finished: set[int] = set()
|
| 447 |
+
|
| 448 |
+
try:
|
| 449 |
+
with args.output.open("a", encoding="utf-8") as out:
|
| 450 |
+
with tqdm(total=len(todo_indices), desc="Generating") as progress:
|
| 451 |
+
def finalize_uid(uid: int, finish_reason: str) -> None:
|
| 452 |
+
finished.add(uid)
|
| 453 |
+
idx = uid_to_idx[uid]
|
| 454 |
+
s = series[idx]
|
| 455 |
+
generated_text = tokenizer.decode(generated_tokens[uid])
|
| 456 |
+
scaled_terms, malformed = parse_generated_terms(generated_text, horizon)
|
| 457 |
+
predictions: list[Decimal | None] = [
|
| 458 |
+
Decimal(term) / Decimal(s.scale) for term in scaled_terms
|
| 459 |
+
]
|
| 460 |
+
predictions.extend([None] * (horizon - len(predictions)))
|
| 461 |
+
fallback = s.context_values[-1] if s.context_values else Decimal(0)
|
| 462 |
+
mape, apes, missing = mape_for_predictions(
|
| 463 |
+
s.target_values,
|
| 464 |
+
predictions,
|
| 465 |
+
missing_policy=args.missing_policy,
|
| 466 |
+
fallback=fallback,
|
| 467 |
+
)
|
| 468 |
+
seasonal_preds = seasonal_naive_predictions(s.context_values, horizon)
|
| 469 |
+
seasonal_mape, seasonal_apes, _ = mape_for_predictions(
|
| 470 |
+
s.target_values,
|
| 471 |
+
seasonal_preds,
|
| 472 |
+
missing_policy="skip",
|
| 473 |
+
fallback=fallback,
|
| 474 |
+
)
|
| 475 |
+
last_preds = last_value_predictions(s.context_values, horizon)
|
| 476 |
+
last_mape, last_apes, _ = mape_for_predictions(
|
| 477 |
+
s.target_values,
|
| 478 |
+
last_preds,
|
| 479 |
+
missing_policy="skip",
|
| 480 |
+
fallback=fallback,
|
| 481 |
+
)
|
| 482 |
+
record = {
|
| 483 |
+
"row_index": idx,
|
| 484 |
+
"series_name": s.series_name,
|
| 485 |
+
"start_timestamp": s.start_timestamp,
|
| 486 |
+
"scale": s.scale,
|
| 487 |
+
"context_length": len(s.context_values),
|
| 488 |
+
"prompt_tokens": len(prompts[idx]),
|
| 489 |
+
"target": [str(x) for x in s.target_values],
|
| 490 |
+
"scaled_prediction_terms": scaled_terms,
|
| 491 |
+
"prediction": [str(x) if x is not None else None for x in predictions],
|
| 492 |
+
"parsed_terms": len(scaled_terms),
|
| 493 |
+
"missing_terms": missing,
|
| 494 |
+
"malformed": malformed,
|
| 495 |
+
"mape": mape,
|
| 496 |
+
"apes": apes,
|
| 497 |
+
"seasonal_naive_mape": seasonal_mape,
|
| 498 |
+
"seasonal_naive_apes": seasonal_apes,
|
| 499 |
+
"last_value_mape": last_mape,
|
| 500 |
+
"last_value_apes": last_apes,
|
| 501 |
+
"generated_text": generated_text,
|
| 502 |
+
"finish_reason": finish_reason,
|
| 503 |
+
}
|
| 504 |
+
out.write(json.dumps(record) + "\n")
|
| 505 |
+
out.flush()
|
| 506 |
+
progress.update(1)
|
| 507 |
+
|
| 508 |
+
while todo_indices:
|
| 509 |
+
responses = gen.next()
|
| 510 |
+
if not isinstance(responses, tuple) or len(responses) != 2:
|
| 511 |
+
raise RuntimeError(
|
| 512 |
+
"Unexpected mlx_lm BatchGenerator.next() API. "
|
| 513 |
+
"Update your mlx_lm version."
|
| 514 |
+
)
|
| 515 |
+
prompt_responses, generation_responses = responses
|
| 516 |
+
if not prompt_responses and not generation_responses:
|
| 517 |
+
break
|
| 518 |
+
remove_uids: list[int] = []
|
| 519 |
+
for response in generation_responses:
|
| 520 |
+
uid = response.uid
|
| 521 |
+
if uid in finished:
|
| 522 |
+
continue
|
| 523 |
+
if response.finish_reason != "stop":
|
| 524 |
+
token = int(response.token)
|
| 525 |
+
generated_tokens[uid].append(token)
|
| 526 |
+
if token == sep_token:
|
| 527 |
+
separator_counts[uid] += 1
|
| 528 |
+
if separator_counts[uid] >= horizon:
|
| 529 |
+
finalize_uid(uid, "separator_count")
|
| 530 |
+
remove_uids.append(uid)
|
| 531 |
+
continue
|
| 532 |
+
if response.finish_reason is not None:
|
| 533 |
+
finalize_uid(uid, str(response.finish_reason))
|
| 534 |
+
if remove_uids:
|
| 535 |
+
gen.remove(remove_uids)
|
| 536 |
+
if len(finished) != len(todo_indices):
|
| 537 |
+
raise RuntimeError(f"Finished {len(finished)}/{len(todo_indices)} series")
|
| 538 |
+
finally:
|
| 539 |
+
gen.close()
|
| 540 |
+
mx.clear_cache()
|
| 541 |
+
gc.collect()
|
| 542 |
+
|
| 543 |
+
all_records = [
|
| 544 |
+
r
|
| 545 |
+
for idx, r in sorted(load_completed(args.output).items())
|
| 546 |
+
if idx in set(eval_indices)
|
| 547 |
+
]
|
| 548 |
+
aggregate = aggregate_summary(all_records, horizon)
|
| 549 |
+
elapsed = time.perf_counter() - started
|
| 550 |
+
summary = {
|
| 551 |
+
"data_path": str(args.data_path),
|
| 552 |
+
"model": str(args.model),
|
| 553 |
+
"output": str(args.output),
|
| 554 |
+
"horizon": horizon,
|
| 555 |
+
"series_loaded": len(series),
|
| 556 |
+
"series_evaluated": len(all_records),
|
| 557 |
+
"skipped_long": skipped_long,
|
| 558 |
+
"max_context_tokens": args.max_context_tokens,
|
| 559 |
+
"max_new_tokens": args.max_new_tokens,
|
| 560 |
+
"batch_size": args.batch_size,
|
| 561 |
+
"missing_policy": args.missing_policy,
|
| 562 |
+
"suppress_eos_until_horizon": args.suppress_eos_until_horizon,
|
| 563 |
+
"seconds": elapsed,
|
| 564 |
+
**aggregate,
|
| 565 |
+
}
|
| 566 |
+
args.summary_output.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8")
|
| 567 |
+
print(json.dumps({k: summary[k] for k in [
|
| 568 |
+
"series_evaluated",
|
| 569 |
+
"model_macro_mape",
|
| 570 |
+
"model_point_mape",
|
| 571 |
+
"seasonal_naive_macro_mape",
|
| 572 |
+
"last_value_macro_mape",
|
| 573 |
+
"parsed_full_series",
|
| 574 |
+
"total_missing_terms",
|
| 575 |
+
"malformed_series",
|
| 576 |
+
"seconds",
|
| 577 |
+
]}, indent=2))
|
| 578 |
+
print(f"Wrote {args.output}")
|
| 579 |
+
print(f"Wrote {args.summary_output}")
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
if __name__ == "__main__":
|
| 583 |
+
main()
|
m1_competition_111.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
m1_monthly_dataset.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
oeis_eval_mlx_neo.py
CHANGED
|
@@ -12,8 +12,20 @@ from mlx_lm import load
|
|
| 12 |
from mlx_lm.generate import BatchGenerator
|
| 13 |
from tqdm import tqdm
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
MAX_NEW_TOKENS = 196
|
| 18 |
MAX_CONTEXT_TOKENS = 4096
|
| 19 |
BATCH_SIZE = 64
|
|
|
|
| 12 |
from mlx_lm.generate import BatchGenerator
|
| 13 |
from tqdm import tqdm
|
| 14 |
|
| 15 |
+
SCRIPT_DIR = Path(__file__).resolve().parent
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def default_model_path() -> str:
|
| 19 |
+
if (SCRIPT_DIR / "model.safetensors").exists():
|
| 20 |
+
return str(SCRIPT_DIR)
|
| 21 |
+
local_model = SCRIPT_DIR / "NextTerm-440M"
|
| 22 |
+
if local_model.exists():
|
| 23 |
+
return str(local_model)
|
| 24 |
+
return "N8Programs/NextTerm-440M"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
DATA_PATH = SCRIPT_DIR / "oeis_val_neo.jsonl"
|
| 28 |
+
MODEL_NAME = default_model_path()
|
| 29 |
MAX_NEW_TOKENS = 196
|
| 30 |
MAX_CONTEXT_TOKENS = 4096
|
| 31 |
BATCH_SIZE = 64
|
oeis_val_neo.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
oeis_val_neo.meta.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "oeis_val_neo",
|
| 3 |
+
"description": "OEIS Eval Neo: oeis_val_decontam with exact packed-sequence overlaps against the synthetic training shard removed.",
|
| 4 |
+
"source_jsonl": "/Users/natebreslow/Documents/khashabiLab/bigOEIS/oeis_val_decontam.jsonl",
|
| 5 |
+
"overlap_source_tsv": "/Users/natebreslow/Documents/khashabiLab/bigOEIS/packed_data/oeis_val_vs_synth_aug0_inv_len_13245370099_seed0.matches_with_ids.tsv",
|
| 6 |
+
"excluded_rows": 73,
|
| 7 |
+
"kept_rows": 19035,
|
| 8 |
+
"indexing": "needle_seq_id is 0-based and jsonl_line_1based = needle_seq_id + 1"
|
| 9 |
+
}
|
oeis_val_neo_excluded_exact_packed_matches.tsv
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
needle_seq_id jsonl_line_1based oeis_id match_pairs first_haystack_seq_id packed_content_tokens terms first_12_terms last_3_terms
|
| 2 |
+
545 546 A167176 2 15370293 209 105 [0, 1, 8, 0, 1, 8, 0, 1, 8, 0, 1, 8] [0, 1, 8]
|
| 3 |
+
812 813 A368056 11 5322373 13 6 [1, 2, 4, 8, 16, 17] [8, 16, 17]
|
| 4 |
+
892 893 A316981 7 4912247 17 7 [1, 1, 2, 6, 15, 40, 121] [15, 40, 121]
|
| 5 |
+
938 939 A000811 1 55273515 153 7 [16, 160, 13056, 183305216, 153746461690757120, 472614400151965797795362900461748224, 22974620880419880070513913016918297106276186594558904213019777370224066560] [153746461690757120, 472614400151965797795362900461748224, 22974620880419880070513913016918297106276186594558904213019777370224066560]
|
| 6 |
+
984 985 A166486 1 14419403 201 101 [0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1] [1, 1, 0]
|
| 7 |
+
1418 1419 A044795 1 6999629 182 39 [82, 182, 282, 382, 482, 582, 682, 782, 829, 882, 982, 1082] [3382, 3482, 3582]
|
| 8 |
+
1561 1562 A273720 1 48019539 224 28 [1, 3, 8, 21, 57, 162, 479, 1458, 4528, 14259, 45349, 145289] [2403295565913, 7966021263923, 26425616887971]
|
| 9 |
+
1870 1871 A044758 1 41462953 182 39 [45, 145, 245, 345, 445, 459, 545, 645, 745, 845, 945, 1045] [3345, 3445, 3459]
|
| 10 |
+
2152 2153 A249541 1 23907330 55 9 [3, 4, 5, 17, 257, 65537, 83623937, 4294967297, 6992962672132097] [83623937, 4294967297, 6992962672132097]
|
| 11 |
+
2375 2376 A279069 1 57156 142 56 [1, 2, 2, 4, 16, 4, 2, 18, 4, 4, 6, 12] [6, 22, 4]
|
| 12 |
+
2429 2430 A010699 1 67423087 161 81 [2, 9, 2, 9, 2, 9, 2, 9, 2, 9, 2, 9] [2, 9, 2]
|
| 13 |
+
2537 2538 A048058 2 13019375 214 51 [11, 13, 17, 23, 31, 41, 53, 67, 83, 101, 121, 143] [2363, 2461, 2561]
|
| 14 |
+
3059 3060 A051637 28 1481960 10 5 [1, 2, 3, 7, 10] [3, 7, 10]
|
| 15 |
+
3298 3299 A179101 5 17275932 15 7 [2, 3, 5, 6, 8, 11, 14] [8, 11, 14]
|
| 16 |
+
3312 3313 A084528 1 23488569 24 8 [1, 2, 5, 17, 59, 201, 703, 2405] [201, 703, 2405]
|
| 17 |
+
3440 3441 A122059 3 4681413 17 9 [1, 0, 0, 1, 1, 2, 3, 0, 4] [3, 0, 4]
|
| 18 |
+
3507 3508 A047578 1 9038770 194 62 [2, 5, 6, 7, 10, 13, 14, 15, 18, 21, 22, 23] [119, 122, 125]
|
| 19 |
+
4130 4131 A228407 1 50962227 192 59 [0, 11, 1, 10, 100, 12, 2, 20, 101, 22, 3, 13] [134, 143, 314]
|
| 20 |
+
4236 4237 A063160 4 13631973 193 48 [10, 33, 57, 81, 105, 129, 153, 177, 201, 225, 249, 273] [1089, 1113, 1137]
|
| 21 |
+
4572 4573 A009836 2 58624055 170 14 [0, 2, 8, 368, 16512, 1583104, 199552000, 36445579264, 8620299812864, 2621816292114432, 987354046567284736, 452732308336619290624] [452732308336619290624, 247917555997339251900416, 159904531672039230122491904]
|
| 22 |
+
4682 4683 A145039 1 5184810 98 18 [3, 7, 19, 31, 107, 127, 607, 1279, 2203, 4423, 86243, 110503] [20996011, 24036583, 25964951]
|
| 23 |
+
5618 5619 A263858 1 45502052 36 17 [1, 1, 1, 1, 1, 3, 1, 1, 7, 6, 2, 1] [18, 4, 2]
|
| 24 |
+
6137 6138 A088689 2 17024337 209 105 [0, 1, 1, 0, 2, 2, 0, 1, 1, 0, 2, 2] [0, 1, 1]
|
| 25 |
+
7257 7258 A044615 1 29890298 186 41 [47, 111, 175, 239, 303, 367, 383, 431, 495, 559, 623, 687] [2223, 2287, 2351]
|
| 26 |
+
7653 7654 A062744 1 41839518 135 15 [1, 1, 10, 145, 2470, 46060, 910252, 18730855, 397089550, 8612835715, 190223180840, 4263421511271] [96723482198980, 2216905597676000, 51256802757808320]
|
| 27 |
+
8155 8156 A270840 1 58982447 204 26 [12, 36, 138, 270, 546, 4800, 7560, 12840, 14700, 358200, 678480, 16139970] [5088964800, 6974736756, 9214178820]
|
| 28 |
+
8251 8252 A155645 1 1385507 193 20 [1, 12, 84, 558, 3696, 24582, 164304, 1103478, 7444416, 50431302, 342941424, 2340123798] [249557173431942, 1729973554578864, 12008254925383638]
|
| 29 |
+
8279 8280 A290453 4 29575809 181 91 [1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 2] [0, 1, 0]
|
| 30 |
+
8477 8478 A044402 3 19293444 182 39 [70, 170, 270, 370, 470, 570, 670, 700, 770, 870, 970, 1070] [3370, 3470, 3570]
|
| 31 |
+
9025 9026 A068987 1 62874440 85 11 [2, 149, 1925, 13808, 49703, 2458886, 9470345, 186557267, 523551503, 191278379840, 4368196101672] [523551503, 191278379840, 4368196101672]
|
| 32 |
+
9707 9708 A212167 3 33046 191 66 [1, 2, 3, 5, 6, 7, 10, 11, 12, 13, 14, 15] [79, 82, 83]
|
| 33 |
+
9721 9722 A293077 3 14353436 199 35 [2, 4, 6, 10, 16, 26, 44, 74, 126, 214, 364, 620] [46697496, 79717612, 136086476]
|
| 34 |
+
9894 9895 A009705 5 24197587 164 14 [0, 2, 4, 302, 8104, 947642, 86855404, 16203909542, 3130092938704, 896924477276402, 290713861720990804, 121467176505314129822] [121467176505314129822, 58492863120535523766904, 33925794100542844193202602]
|
| 35 |
+
10305 10306 A135839 2 3672765 181 91 [1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0] [1, 0, 1]
|
| 36 |
+
10600 10601 A101587 2 594093 59 14 [2, 4, 5, 20, 308, 815, 857, 1114, 1418, 1688, 12008, 18692] [18692, 28097, 90964]
|
| 37 |
+
10681 10682 A024036 3 21353804 219 25 [0, 3, 15, 63, 255, 1023, 4095, 16383, 65535, 262143, 1048575, 4194303] [17592186044415, 70368744177663, 281474976710655]
|
| 38 |
+
10706 10707 A100285 1 18949167 179 90 [1, 1, 5, 5, 1, 1, 5, 5, 1, 1, 5, 5] [5, 1, 1]
|
| 39 |
+
11007 11008 A055843 1 3608436 231 31 [1, 13, 85, 385, 1375, 4147, 11011, 26455, 58630, 121550, 238238, 445094] [406833460, 536222500, 700950052]
|
| 40 |
+
11079 11080 A017629 2 52873588 202 53 [9, 21, 33, 45, 57, 69, 81, 93, 105, 117, 129, 141] [609, 621, 633]
|
| 41 |
+
11234 11235 A007520 1 28251297 202 51 [3, 11, 19, 43, 59, 67, 83, 107, 131, 139, 163, 179] [1163, 1171, 1187]
|
| 42 |
+
11307 11308 A299729 1 7973728 201 54 [12, 24, 30, 36, 40, 48, 60, 63, 70, 72, 80, 84] [320, 324, 325]
|
| 43 |
+
11450 11451 A010126 1 52767868 161 81 [4, 1, 2, 4, 2, 1, 8, 1, 2, 4, 2, 1] [8, 1, 2]
|
| 44 |
+
12022 12023 A380991 9 1893358 9 4 [1, 4, 15, 48] [4, 15, 48]
|
| 45 |
+
12091 12092 A047612 2 10789238 196 63 [0, 2, 4, 5, 8, 10, 12, 13, 16, 18, 20, 21] [120, 122, 124]
|
| 46 |
+
12305 12306 A063144 1 8491582 189 49 [8, 27, 47, 67, 87, 107, 127, 147, 167, 187, 207, 227] [927, 947, 967]
|
| 47 |
+
12646 12647 A129756 1 3311430 207 76 [1, 1, 1, 1, 3, 3, 3, 3, 5, 5, 5, 5] [37, 37, 37]
|
| 48 |
+
13501 13502 A050621 1 20013751 188 18 [2, 12, 104, 1008, 10016, 100032, 1000064, 10000128, 100000256, 1000000512, 10000001024, 100000002048] [1000000000032768, 10000000000065536, 100000000000131072]
|
| 49 |
+
13614 13615 A126048 3 11143698 95 48 [2, 3, 5, 7, 5, 1, 3, 7, 5, 1, 3, 7] [1, 1, 1]
|
| 50 |
+
13703 13704 A163471 2 46480945 194 20 [3, 18, 114, 744, 4932, 32952, 221016, 1485216, 9989808, 67223328, 452457504, 3045661824] [283481998053888, 1908413999583744, 12847536038651904]
|
| 51 |
+
13704 13705 A309224 3 10124493 99 35 [1, 1, 0, 0, 2, 1, 1, 2, 2, 3, 5, 7] [185, 202, 222]
|
| 52 |
+
13737 13738 A101540 1 33989933 43 10 [4, 56, 500, 514, 626, 640, 724, 53110, 65330, 109672] [53110, 65330, 109672]
|
| 53 |
+
14074 14075 A160124 1 19649603 206 57 [0, 0, 0, 2, 4, 4, 8, 18, 24, 24, 28, 36] [932, 1000, 1028]
|
| 54 |
+
14081 14082 A010704 2 33797649 161 81 [3, 6, 3, 6, 3, 6, 3, 6, 3, 6, 3, 6] [3, 6, 3]
|
| 55 |
+
14379 14380 A327424 2 605474 23 8 [1, 1, 2, 4, 10, 33, 234, 16579] [33, 234, 16579]
|
| 56 |
+
14391 14392 A212611 2 5704190 167 24 [135, 3375, 25137, 68607, 22113, 557375, 451737, 1278316, 273, 6739509, 188325, 8775] [366776, 1487640245, 417725]
|
| 57 |
+
14435 14436 A139302 5 8333360 100 4 [2, 512, 144115188075855872, 904625697166532776746648320380374280103671755200316906558262375061821325312] [512, 144115188075855872, 904625697166532776746648320380374280103671755200316906558262375061821325312]
|
| 58 |
+
14960 14961 A057556 1 65334021 335 168 [0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0] [5, 0, 0]
|
| 59 |
+
15368 15369 A056263 1 51823997 59 13 [1, 3, 27, 155, 321, 351, 1211, 1283, 7983, 15191, 84771, 119929] [84771, 119929, 148859]
|
| 60 |
+
15953 15954 A007097 1 484792 261 24 [1, 2, 3, 5, 11, 31, 127, 709, 5381, 52711, 648391, 9737333] [9332039515881088707361, 499720579610303128776791, 28785866289100396890228041]
|
| 61 |
+
16105 16106 A082705 4 11152753 23 6 [7, 9, 895, 1525, 3037, 21157] [1525, 3037, 21157]
|
| 62 |
+
16931 16932 A255771 53 771418 21 11 [1, 1, 1, 2, 2, 1, 2, 2, 4, 2, 2] [4, 2, 2]
|
| 63 |
+
17007 17008 A389124 2 37692541 230 29 [1, 3, 6, 17, 46, 133, 386, 1137, 3366, 10013, 29866, 89257] [1270955283786, 3812843481737, 11438485705966]
|
| 64 |
+
17160 17161 A117995 1 52396131 207 51 [0, 0, 1, 1, 2, 3, 4, 6, 8, 11, 14, 20] [26252, 30700, 35717]
|
| 65 |
+
17170 17171 A271449 3 5003778 13 6 [0, 3, 6, 9, 12, 18] [9, 12, 18]
|
| 66 |
+
17255 17256 A176393 1 8944574 205 61 [1, 3, 9, 13, 17, 19, 21, 25, 29, 31, 33, 37] [161, 163, 165]
|
| 67 |
+
17300 17301 A085070 11 9630624 8 4 [1, 2, 9, 26] [2, 9, 26]
|
| 68 |
+
17347 17348 A295745 1 37424184 8 3 [18, 20, 24] [18, 20, 24]
|
| 69 |
+
18185 18186 A380287 4 8734831 113 18 [4, 6, 16, 48, 142, 472, 1670, 6364, 24604, 97668, 390070, 1570560] [419930444, 1701635046, 6898183050]
|
| 70 |
+
18227 18228 A133384 1 15280426 227 19 [12, 102, 1002, 10002, 100002, 1000002, 10000002, 100000002, 1000000002, 10000000002, 100000000002, 1000000000002] [100000000000000002, 1000000000000000002, 10000000000000000002]
|
| 71 |
+
18258 18259 A011803 2 34513693 58 10 [1, 2, 8, 64, 904, 20926, 753994, 40412530, 3099627142, 329518779600] [40412530, 3099627142, 329518779600]
|
| 72 |
+
18304 18305 A137444 2 13209959 217 41 [1, 4, 6, 4, -4, -16, -24, -16, 16, 64, 96, 64] [-1572864, -1048576, 1048576]
|
| 73 |
+
18559 18560 A083318 1 45434387 198 32 [1, 3, 5, 9, 17, 33, 65, 129, 257, 513, 1025, 2049] [536870913, 1073741825, 2147483649]
|
| 74 |
+
18605 18606 A141044 2 3030285 209 105 [2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] [1, 1, 1]
|