| """ |
| WikiText-2 scorer for İvme — reports cross-entropy loss and perplexity. |
| |
| The Tiny-ML leaderboard's "WikiText-2 ↓" column reports per-token cross-entropy |
| loss (e.g. competitors at 2.66, 3.08), NOT perplexity. We print both so you can |
| match whichever the leaderboard uses. |
| |
| Method: concatenate the WikiText-2 test split, tokenize, and score with a |
| sliding window of the model's context length, summing log-probs over all |
| predicted tokens. CE loss = -mean(log p(token)). Perplexity = exp(CE loss). |
| |
| Usage: |
| python eval_wikitext.py --checkpoint checkpoints/ivme_base_ema.pt |
| """ |
|
|
| from __future__ import annotations |
| import argparse |
| import json |
| import math |
| import sys |
| import torch |
| import torch.nn.functional as F |
| from tokenizers import Tokenizer |
| from datasets import load_dataset |
|
|
| sys.path.insert(0, ".") |
| from model import IvmeConfig, IvmeConversate |
|
|
# Tokenizer definition file; a relative path, so it is resolved against the
# current working directory.
TOKENIZER_PATH = "ivme_tokenizer.json"


@torch.no_grad()
def _score_windows(model, ids, ctx, stride, device):
    """Sum the negative log-likelihood of ``ids`` with a sliding window.

    Each window feeds up to ``ctx`` tokens to ``model`` and predicts the token
    following every input position. When windows overlap (``stride < ctx``)
    the overlapping prefix only provides context: a token that an earlier
    window already scored is not scored again, so every predictable token is
    counted at most once.

    Args:
        model: callable taking a ``(1, T)`` long tensor and returning a pair
            whose first item holds the logits, indexed ``[0, position, vocab]``.
        ids: list of token ids for the whole evaluation text.
        ctx: maximum number of input tokens per window.
        stride: distance between consecutive window starts; must be >= 1.
        device: ``torch.device`` on which the input tensors are created.

    Returns:
        ``(total_nll, n_scored)``: the summed negative log-likelihood in nats
        and the number of tokens it was summed over.
    """
    total_nll = 0.0
    n_scored = 0
    n_ids = len(ids)
    # Token 0 has no left context, so it is never a prediction target.
    scored_upto = 1
    for start in range(0, n_ids - 1, stride):
        end = min(start + ctx + 1, n_ids)
        window = ids[start:end]
        inp = torch.tensor([window[:-1]], dtype=torch.long, device=device)
        tgt = torch.tensor(window[1:], dtype=torch.long, device=device)
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                            enabled=device.type == "cuda"):
            logits, _ = model(inp)
        # The first target of this window is token ``start + 1``; drop the
        # targets that a previous (overlapping) window has already scored.
        skip = max(scored_upto - (start + 1), 0)
        nll = F.cross_entropy(logits[0, skip:].float(), tgt[skip:],
                              reduction="sum")
        total_nll += nll.item()
        n_scored += tgt.numel() - skip
        scored_upto = end
        if end == n_ids:
            # This window reached the end of the data; later windows would
            # only contain tokens that have already been scored.
            break
    return total_nll, n_scored


@torch.no_grad()
def main():
    """Score a checkpoint on the WikiText-2 test split.

    Parses ``--checkpoint``, ``--output`` and ``--stride`` from the command
    line, loads the model and tokenizer, computes the per-token cross-entropy
    loss and the perplexity, prints both, and writes them as JSON to
    ``--output``.

    Exits with an error if ``--stride`` is negative or if the tokenized text
    has fewer than two tokens (nothing can be predicted).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--output", default="wikitext_results.json")
    ap.add_argument("--stride", type=int, default=None,
                    help="sliding-window stride; defaults to full context (non-overlapping)")
    args = ap.parse_args()
    if args.stride is not None and args.stride < 0:
        # A negative step would make the window loop empty.
        ap.error("--stride must not be negative")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = Tokenizer.from_file(TOKENIZER_PATH)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # which can execute code. Only load checkpoints from a trusted source.
    # It is presumably required because the checkpoint stores the config
    # object under "cfg" -- confirm before tightening.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    cfg = ckpt["cfg"]
    cfg.attn_backend = "sdpa"
    ctx = cfg.max_seq_len
    model = IvmeConversate(cfg).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    print(f"[wikitext] model loaded: {model.num_params()/1e6:.1f}M on {device}")

    print("[wikitext] loading WikiText-2 test split...")
    ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")
    # Drop blank rows and join the rest into one continuous document.
    text = "\n\n".join(t for t in ds["text"] if t.strip())
    ids = tok.encode(text).ids
    print(f"[wikitext] {len(ids):,} tokens")

    # An omitted (or zero) stride means non-overlapping, full-context windows.
    stride = args.stride or ctx
    if stride > ctx:
        print(f"[wikitext] WARNING: stride {stride} exceeds context length {ctx}; "
              "tokens between windows will not be scored")

    total_nll, total_tokens = _score_windows(model, ids, ctx, stride, device)
    if total_tokens == 0:
        sys.exit("[wikitext] nothing to score: need at least 2 tokens")

    ce_loss = total_nll / total_tokens
    ppl = math.exp(ce_loss)
    print(f"\n{'='*52}")
    print(f" WikiText-2 cross-entropy loss : {ce_loss:.4f}")
    print(f" WikiText-2 perplexity : {ppl:.2f}")
    print(f"{'='*52}")
    print(" (leaderboard column reports CE loss, lower is better)")

    with open(args.output, "w", encoding="utf-8") as f:
        json.dump({"wikitext2_ce_loss": ce_loss, "wikitext2_ppl": ppl,
                   "tokens": total_tokens}, f, indent=2)
    print(f"\n[wikitext] saved -> {args.output}")
|
|
|
|
# Run the scorer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()