# Source: Ivme-Conversate-22M-Base / eval_wikitext.py
# (uploaded by ereniko with huggingface_hub; commit 7e5763c, verified)
"""
WikiText-2 scorer for İvme — reports cross-entropy loss and perplexity.

The Tiny-ML leaderboard's "WikiText-2 ↓" column reports per-token cross-entropy
loss (e.g. competitors at 2.66 and 3.08), NOT perplexity. We print both so you
can match whichever metric the leaderboard uses.

Method: join the non-blank rows of the WikiText-2 (raw) test split, tokenize,
and score with a sliding window of the model's context length, summing log-probs
over all predicted tokens. CE loss = -mean(log p(token)). Perplexity = exp(CE loss).
The results are also written to a JSON file (--output, default wikitext_results.json).

Usage:
    python eval_wikitext.py --checkpoint checkpoints/ivme_base_ema.pt
    (optional: --stride N sets the window step; default is the full context length)
"""
from __future__ import annotations
import argparse
import json
import math
import sys
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from datasets import load_dataset
sys.path.insert(0, ".")
from model import IvmeConfig, IvmeConversate
# Tokenizer JSON file read by main(). The path is relative, so it resolves
# against the current working directory.
TOKENIZER_PATH: str = "ivme_tokenizer.json"
def _load_model(checkpoint: str, device: torch.device):
    """Rebuild the model stored in *checkpoint* on *device*; return ``(model, ctx)``.

    Reads ``ckpt["cfg"]`` (the model config) and ``ckpt["model"]`` (the state
    dict). ``ctx`` is the config's ``max_seq_len``. The model is returned in
    eval mode.
    """
    # NOTE(security): weights_only=False runs the full unpickler, which can
    # execute arbitrary code. This is presumably required because "cfg" is a
    # pickled config object rather than plain tensors. Only evaluate trusted
    # checkpoints.
    ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
    cfg = ckpt["cfg"]
    # Select the "sdpa" attention backend, overriding whatever the checkpoint saved.
    cfg.attn_backend = "sdpa"
    model = IvmeConversate(cfg).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    return model, cfg.max_seq_len


def _load_test_ids(tok) -> list[int]:
    """Return the token ids of the WikiText-2 (raw) test split.

    Blank rows are dropped. The remaining rows are joined with blank-line
    separators and encoded with a single ``tok.encode`` call.
    """
    ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")
    text = "\n\n".join(t for t in ds["text"] if t.strip())
    return tok.encode(text).ids


@torch.no_grad()
def _sliding_window_nll(model, ids: list[int], ctx: int, stride: int,
                        device: torch.device) -> tuple[float, int]:
    """Sum the negative log-likelihood of every token of *ids* after the first.

    Windows of up to ``ctx`` input tokens start every ``stride`` tokens
    (expects ``1 <= stride <= ctx``). When windows overlap (stride < ctx),
    only targets that no earlier window has scored contribute. Each token is
    therefore counted exactly once, and the overlap serves purely as extra
    left context.

    Returns ``(total_nll, total_tokens)``.
    """
    total_nll = 0.0
    total_tokens = 0
    n = len(ids)
    scored_upto = 0  # position of the last target already scored (0 = none yet)
    for start in range(0, n - 1, stride):
        # start <= n - 2, so the chunk always holds at least one (input, target) pair.
        chunk = ids[start : start + ctx + 1]
        # Targets chunk[1:] sit at positions start+1 .. start+len(chunk)-1.
        # Skip the leading ones that an overlapping earlier window already scored.
        skip = max(0, scored_upto - start)
        inp = torch.tensor([chunk[:-1]], dtype=torch.long, device=device)
        tgt = torch.tensor(chunk[1:], dtype=torch.long, device=device)
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                            enabled=device.type == "cuda"):
            logits, _ = model(inp)
        # The loss is computed in float32, outside autocast, for numerical accuracy.
        total_nll += F.cross_entropy(logits[0, skip:].float(), tgt[skip:],
                                     reduction="sum").item()
        total_tokens += tgt.numel() - skip
        scored_upto = start + len(chunk) - 1
        if scored_upto >= n - 1:
            break  # last token scored; any further windows would add nothing new
    return total_nll, total_tokens


@torch.no_grad()
def main():
    """Score a checkpoint on WikiText-2, print CE loss and perplexity, and save JSON.

    CLI options: --checkpoint (required), --output, --tokenizer, --stride.
    Exits with a usage error when --stride is below 1 or exceeds the model's
    context length, because such strides would leave tokens unscored.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--output", default="wikitext_results.json")
    ap.add_argument("--tokenizer", default=TOKENIZER_PATH,
                    help="tokenizers JSON file (default: %(default)s)")
    ap.add_argument("--stride", type=int, default=None,
                    help="sliding-window stride (1..context length); defaults to "
                         "full context (non-overlapping)")
    args = ap.parse_args()
    # Fail fast on an impossible stride before any heavy loading.
    if args.stride is not None and args.stride < 1:
        ap.error("--stride must be a positive integer")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = Tokenizer.from_file(args.tokenizer)
    model, ctx = _load_model(args.checkpoint, device)
    print(f"[wikitext] model loaded: {model.num_params()/1e6:.1f}M on {device}")

    stride = ctx if args.stride is None else args.stride
    if stride > ctx:
        ap.error(f"--stride ({stride}) must not exceed the context length ({ctx}); "
                 "larger strides would leave tokens unscored")

    print("[wikitext] loading WikiText-2 test split...")
    ids = _load_test_ids(tok)
    print(f"[wikitext] {len(ids):,} tokens")

    total_nll, total_tokens = _sliding_window_nll(model, ids, ctx, stride, device)
    if total_tokens == 0:
        # Fewer than 2 tokens means nothing to predict; avoid a ZeroDivisionError.
        raise SystemExit("[wikitext] nothing to score: the test split encoded "
                         "to fewer than 2 tokens")
    ce_loss = total_nll / total_tokens
    ppl = math.exp(ce_loss)

    print(f"\n{'=' * 52}")
    print(f" WikiText-2 cross-entropy loss : {ce_loss:.4f}")
    print(f" WikiText-2 perplexity : {ppl:.2f}")
    print("=" * 52)
    print(" (leaderboard column reports CE loss, lower is better)")
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump({"wikitext2_ce_loss": ce_loss, "wikitext2_ppl": ppl,
                   "tokens": total_tokens, "stride": stride}, f, indent=2)
    print(f"\n[wikitext] saved -> {args.output}")
if __name__ == "__main__":
    # Script entry point: importing this module does not start an evaluation.
    main()