""" Batched BLiMP scorer for İvme — fast, GPU-parallel. Scores all 67 BLiMP subtasks by batching sentence pairs through the model instead of looping one at a time. On a Blackwell this runs the whole suite in well under a minute. Method: for each (good, bad) pair, compute total log-prob of each sentence and count a win when logprob(good) > logprob(bad). Sentences are padded into batches and scored with a length mask so padding contributes nothing. Usage: python eval_blimp.py --checkpoint checkpoints/ivme_base_ema.pt python eval_blimp.py --checkpoint checkpoints/ivme_base_ema.pt --batch_size 256 """ from __future__ import annotations import argparse import json import sys import torch import torch.nn.functional as F from tokenizers import Tokenizer from datasets import load_dataset sys.path.insert(0, ".") from model import IvmeConfig, IvmeConversate TOKENIZER_PATH = "ivme_tokenizer.json" BLIMP_TASKS = [ "adjunct_island", "anaphor_gender_agreement", "anaphor_number_agreement", "animate_subject_passive", "animate_subject_trans", "causative", "complex_NP_island", "coordinate_structure_constraint_complex_left_branch", "coordinate_structure_constraint_object_extraction", "determiner_noun_agreement_1", "determiner_noun_agreement_2", "determiner_noun_agreement_irregular_1", "determiner_noun_agreement_irregular_2", "determiner_noun_agreement_with_adj_2", "determiner_noun_agreement_with_adj_irregular_1", "determiner_noun_agreement_with_adj_irregular_2", "determiner_noun_agreement_with_adjective_1", "distractor_agreement_relational_noun", "distractor_agreement_relative_clause", "drop_argument", "ellipsis_n_bar_1", "ellipsis_n_bar_2", "existential_there_object_raising", "existential_there_quantifiers_1", "existential_there_quantifiers_2", "existential_there_subject_raising", "expletive_it_object_raising", "inchoative", "intransitive", "irregular_past_participle_adjectives", "irregular_past_participle_verbs", "irregular_plural_subject_verb_agreement_1", "irregular_plural_subject_verb_agreement_2", "left_branch_island_echo_question", "left_branch_island_simple_question", "matrix_question_npi_licensor_present", "npi_present_1", "npi_present_2", "only_npi_licensor_present", "only_npi_scope", "passive_1", "passive_2", "principle_A_c_command", "principle_A_case_1", "principle_A_case_2", "principle_A_domain_1", "principle_A_domain_2", "principle_A_domain_3", "principle_A_reconstruction", "regular_plural_subject_verb_agreement_1", "regular_plural_subject_verb_agreement_2", "sentential_negation_npi_licensor_present", "sentential_negation_npi_scope", "sentential_subject_island", "superlative_quantifiers_1", "superlative_quantifiers_2", "tough_vs_raising_1", "tough_vs_raising_2", "transitive", "wh_island", "wh_questions_object_gap", "wh_questions_subject_gap", "wh_questions_subject_gap_long_distance", "wh_vs_that_no_gap", "wh_vs_that_no_gap_long_distance", "wh_vs_that_with_gap", "wh_vs_that_with_gap_long_distance", ] @torch.no_grad() def batch_logprobs(model, token_lists, device, pad_id, max_len): """Total log-prob of each sequence in a padded batch. token_lists: list[list[int]].""" B = len(token_lists) L = min(max(len(t) for t in token_lists), max_len) inp = torch.full((B, L), pad_id, dtype=torch.long, device=device) lengths = [] for i, t in enumerate(token_lists): t = t[:L] inp[i, : len(t)] = torch.tensor(t, dtype=torch.long, device=device) lengths.append(len(t)) with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=device.type == "cuda"): logits, _ = model(inp) logp = F.log_softmax(logits.float(), dim=-1) targets = inp[:, 1:] pred = logp[:, :-1, :] tok_lp = pred.gather(-1, targets.unsqueeze(-1)).squeeze(-1) mask = torch.zeros_like(tok_lp) for i, n in enumerate(lengths): mask[i, : max(0, n - 1)] = 1.0 return (tok_lp * mask).sum(dim=1) def main(): ap = argparse.ArgumentParser() ap.add_argument("--checkpoint", required=True) ap.add_argument("--batch_size", type=int, default=256) ap.add_argument("--output", default="blimp_results.json") args = ap.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tok = Tokenizer.from_file(TOKENIZER_PATH) pad_id = tok.token_to_id("<|pad|>") or 0 ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False) cfg = ckpt["cfg"] cfg.attn_backend = "sdpa" max_len = cfg.max_seq_len model = IvmeConversate(cfg).to(device) model.load_state_dict(ckpt["model"]) model.eval() print(f"[blimp] model loaded: {model.num_params()/1e6:.1f}M on {device}") print("[blimp] loading full BLiMP dataset (one download)...") full_ds = load_dataset("WillHeld/blimp", split="train") by_task = {t: {"good": [], "bad": []} for t in BLIMP_TASKS} for row in full_ds: uid = row["UID"] if uid in by_task: by_task[uid]["good"].append(row["sentence_good"]) by_task[uid]["bad"].append(row["sentence_bad"]) print(f"[blimp] {len(full_ds)} examples bucketed into {len(BLIMP_TASKS)} subtasks\n") results = {} total_correct = total_examples = 0 for i, task in enumerate(BLIMP_TASKS): goods = by_task[task]["good"] bads = by_task[task]["bad"] good_tok = [tok.encode(s).ids for s in goods] bad_tok = [tok.encode(s).ids for s in bads] correct = 0 for start in range(0, len(good_tok), args.batch_size): gb = good_tok[start : start + args.batch_size] bb = bad_tok[start : start + args.batch_size] g_lp = batch_logprobs(model, gb, device, pad_id, max_len) b_lp = batch_logprobs(model, bb, device, pad_id, max_len) correct += (g_lp > b_lp).sum().item() acc = correct / len(goods) results[task] = acc total_correct += correct total_examples += len(goods) running = total_correct / total_examples print(f"[{i+1:02d}/{len(BLIMP_TASKS)}] {task:<55} {acc*100:5.1f}% " f"(avg: {running*100:.2f}%)") final = total_correct / total_examples print(f"\n{'='*60}") print(f" BLiMP average: {final*100:.2f}% (random baseline: 50%)") print(f"{'='*60}") with open(args.output, "w") as f: json.dump({"tasks": results, "average": final}, f, indent=2) print(f"\n[blimp] saved -> {args.output}") if __name__ == "__main__": main()