| """ |
| Batched BLiMP scorer for İvme — fast, GPU-parallel. |
| |
| Scores all 67 BLiMP subtasks by batching sentence pairs through the model |
| instead of looping one at a time. On a Blackwell this runs the whole suite |
| in well under a minute. |
| |
| Method: for each (good, bad) pair, compute total log-prob of each sentence |
| and count a win when logprob(good) > logprob(bad). Sentences are padded into |
| batches and scored with a length mask so padding contributes nothing. |
| |
| Usage: |
| python eval_blimp.py --checkpoint checkpoints/ivme_base_ema.pt |
| python eval_blimp.py --checkpoint checkpoints/ivme_base_ema.pt --batch_size 256 |
| """ |
|
|
| from __future__ import annotations |
| import argparse |
| import json |
| import sys |
| import torch |
| import torch.nn.functional as F |
| from tokenizers import Tokenizer |
| from datasets import load_dataset |
|
|
| sys.path.insert(0, ".") |
| from model import IvmeConfig, IvmeConversate |
|
|
# Tokenizer file, resolved relative to the current working directory.
TOKENIZER_PATH = "ivme_tokenizer.json"

# Subtask identifiers to score, one per line.  The order here is the order in
# which subtasks are evaluated and reported.
BLIMP_TASKS = [
    "adjunct_island",
    "anaphor_gender_agreement",
    "anaphor_number_agreement",
    "animate_subject_passive",
    "animate_subject_trans",
    "causative",
    "complex_NP_island",
    "coordinate_structure_constraint_complex_left_branch",
    "coordinate_structure_constraint_object_extraction",
    "determiner_noun_agreement_1",
    "determiner_noun_agreement_2",
    "determiner_noun_agreement_irregular_1",
    "determiner_noun_agreement_irregular_2",
    "determiner_noun_agreement_with_adj_2",
    "determiner_noun_agreement_with_adj_irregular_1",
    "determiner_noun_agreement_with_adj_irregular_2",
    "determiner_noun_agreement_with_adjective_1",
    "distractor_agreement_relational_noun",
    "distractor_agreement_relative_clause",
    "drop_argument",
    "ellipsis_n_bar_1",
    "ellipsis_n_bar_2",
    "existential_there_object_raising",
    "existential_there_quantifiers_1",
    "existential_there_quantifiers_2",
    "existential_there_subject_raising",
    "expletive_it_object_raising",
    "inchoative",
    "intransitive",
    "irregular_past_participle_adjectives",
    "irregular_past_participle_verbs",
    "irregular_plural_subject_verb_agreement_1",
    "irregular_plural_subject_verb_agreement_2",
    "left_branch_island_echo_question",
    "left_branch_island_simple_question",
    "matrix_question_npi_licensor_present",
    "npi_present_1",
    "npi_present_2",
    "only_npi_licensor_present",
    "only_npi_scope",
    "passive_1",
    "passive_2",
    "principle_A_c_command",
    "principle_A_case_1",
    "principle_A_case_2",
    "principle_A_domain_1",
    "principle_A_domain_2",
    "principle_A_domain_3",
    "principle_A_reconstruction",
    "regular_plural_subject_verb_agreement_1",
    "regular_plural_subject_verb_agreement_2",
    "sentential_negation_npi_licensor_present",
    "sentential_negation_npi_scope",
    "sentential_subject_island",
    "superlative_quantifiers_1",
    "superlative_quantifiers_2",
    "tough_vs_raising_1",
    "tough_vs_raising_2",
    "transitive",
    "wh_island",
    "wh_questions_object_gap",
    "wh_questions_subject_gap",
    "wh_questions_subject_gap_long_distance",
    "wh_vs_that_no_gap",
    "wh_vs_that_no_gap_long_distance",
    "wh_vs_that_with_gap",
    "wh_vs_that_with_gap_long_distance",
]
|
|
|
|
@torch.no_grad()
def batch_logprobs(model, token_lists, device, pad_id, max_len):
    """Return the summed next-token log-probability of each sequence.

    Sequences are truncated to ``max_len`` tokens, right-padded with
    ``pad_id`` to a common width and scored in a single forward pass.
    Logit position ``j`` is scored against token ``j + 1``, so a sequence of
    ``n`` tokens contributes ``n - 1`` terms and its first token is never
    scored.

    Args:
        model: Callable that takes a ``(batch, width)`` tensor of token ids
            and returns a 2-tuple whose first item is the logits, indexed
            here as ``(batch, width, vocab)``.
        token_lists: ``list[list[int]]`` of token ids, one list per sequence.
        device: ``torch.device`` on which the batch is built and scored.
        pad_id: Token id written into positions past the end of a sequence.
        max_len: Maximum number of tokens kept per sequence.

    Returns:
        A float32 tensor of shape ``(len(token_lists),)`` on ``device``.
        Sequences with fewer than two tokens score ``0.0``.
    """
    # NOTE(review): right-padding leaves real positions untouched only if the
    # model attends causally — confirm for the model being scored.
    if not token_lists:
        # Nothing to score; the original max() over an empty batch raised.
        return torch.zeros(0, dtype=torch.float32, device=device)

    width = min(max(len(seq) for seq in token_lists), max_len)
    rows = [list(seq[:width]) for seq in token_lists]
    if width < 2:
        # No (context, target) pair exists, so skip the forward pass.
        return torch.zeros(len(rows), dtype=torch.float32, device=device)

    # Pad in Python and create the batch with one tensor construction instead
    # of one small host-to-device copy per row.
    padded = [row + [pad_id] * (width - len(row)) for row in rows]
    inp = torch.tensor(padded, dtype=torch.long, device=device)
    lengths = torch.tensor([len(row) for row in rows], dtype=torch.long, device=device)

    with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                        enabled=device.type == "cuda"):
        logits, _ = model(inp)

    # Normalise in float32 regardless of the autocast dtype.
    logp = F.log_softmax(logits.float(), dim=-1)
    targets = inp[:, 1:]
    tok_lp = logp[:, :-1, :].gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # Target position j is real only when j < n - 1 for a sequence of n tokens.
    steps = torch.arange(width - 1, device=device)
    scored = steps.unsqueeze(0) < (lengths.unsqueeze(1) - 1)

    # Zero padded positions with masked_fill rather than multiplying by a 0/1
    # mask: a padded position whose log-prob is -inf would give -inf * 0 = NaN.
    return tok_lp.masked_fill(~scored, 0.0).sum(dim=1)
|
|
|
|
def main():
    """Score a checkpoint on every subtask in BLIMP_TASKS and save the results.

    Parses ``--checkpoint``, ``--batch_size`` and ``--output``, loads the
    tokenizer and model, buckets the dataset rows by their ``UID`` column,
    and counts a pair as correct when the good sentence's total log-prob is
    strictly greater than the bad sentence's.  Per-subtask accuracies and the
    average over all scored pairs are printed and written to ``--output`` as
    JSON with keys ``"tasks"`` and ``"average"``.

    Subtasks with no rows in the dataset are reported and left out of the
    results instead of aborting the run.

    Raises:
        SystemExit: if ``--batch_size`` is not positive, or if no dataset row
            matched any known subtask.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--batch_size", type=int, default=256)
    ap.add_argument("--output", default="blimp_results.json")
    args = ap.parse_args()
    if args.batch_size < 1:
        # range() rejects a zero step and yields nothing for a negative one,
        # which would silently report 0% accuracy.
        ap.error("--batch_size must be a positive integer")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = Tokenizer.from_file(TOKENIZER_PATH)
    pad_id = tok.token_to_id("<|pad|>")
    if pad_id is None:
        # Tokenizer has no dedicated pad token; fall back to id 0.
        pad_id = 0

    # SECURITY: weights_only=False unpickles arbitrary objects (the checkpoint
    # stores its config object under "cfg"), so only load trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    cfg = ckpt["cfg"]
    cfg.attn_backend = "sdpa"
    max_len = cfg.max_seq_len
    model = IvmeConversate(cfg).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    print(f"[blimp] model loaded: {model.num_params()/1e6:.1f}M on {device}")

    print("[blimp] loading full BLiMP dataset (one download)...")
    full_ds = load_dataset("WillHeld/blimp", split="train")
    by_task = {t: {"good": [], "bad": []} for t in BLIMP_TASKS}
    for row in full_ds:
        bucket = by_task.get(row["UID"])
        if bucket is not None:
            bucket["good"].append(row["sentence_good"])
            bucket["bad"].append(row["sentence_bad"])
    # Report what was actually kept: rows with an unknown UID are dropped.
    n_bucketed = sum(len(bucket["good"]) for bucket in by_task.values())
    print(f"[blimp] {n_bucketed} examples bucketed into {len(BLIMP_TASKS)} subtasks\n")

    results = {}
    total_correct = total_examples = 0

    for i, task in enumerate(BLIMP_TASKS):
        goods = by_task[task]["good"]
        bads = by_task[task]["bad"]
        if not goods:
            # Avoid dividing by zero when the dataset lacks this subtask.
            print(f"[{i+1:02d}/{len(BLIMP_TASKS)}] {task:<55} skipped (no examples)")
            continue

        good_tok = [tok.encode(s).ids for s in goods]
        bad_tok = [tok.encode(s).ids for s in bads]

        correct = 0
        for start in range(0, len(good_tok), args.batch_size):
            gb = good_tok[start : start + args.batch_size]
            bb = bad_tok[start : start + args.batch_size]
            g_lp = batch_logprobs(model, gb, device, pad_id, max_len)
            b_lp = batch_logprobs(model, bb, device, pad_id, max_len)
            # Strict inequality: a tie counts as a miss.
            correct += (g_lp > b_lp).sum().item()

        acc = correct / len(goods)
        results[task] = acc
        total_correct += correct
        total_examples += len(goods)
        running = total_correct / total_examples
        print(f"[{i+1:02d}/{len(BLIMP_TASKS)}] {task:<55} {acc*100:5.1f}%  "
              f"(avg: {running*100:.2f}%)")

    if total_examples == 0:
        raise SystemExit(
            "[blimp] no dataset rows matched any known subtask; "
            "check the dataset's UID column"
        )

    # Average over all scored pairs, not the mean of per-subtask accuracies.
    final = total_correct / total_examples
    print(f"\n{'='*60}")
    print(f"  BLiMP average: {final*100:.2f}%   (random baseline: 50%)")
    print(f"{'='*60}")

    with open(args.output, "w", encoding="utf-8") as f:
        json.dump({"tasks": results, "average": final}, f, indent=2)
    print(f"\n[blimp] saved -> {args.output}")
|
|
|
|
# Run the scorer only when this file is executed as a script, so importing
# the module does not start an evaluation.
if __name__ == "__main__":
    main()