Ivme-Conversate-22M-Base / eval_blimp.py
ereniko's picture
Upload eval_blimp.py with huggingface_hub
337273e verified
"""
Batched BLiMP scorer for İvme — fast, GPU-parallel.
Scores all 67 BLiMP subtasks by batching sentence pairs through the model
instead of looping one at a time. On a Blackwell this runs the whole suite
in well under a minute.
Method: for each (good, bad) pair, compute total log-prob of each sentence
and count a win when logprob(good) > logprob(bad). Sentences are padded into
batches and scored with a length mask so padding contributes nothing.
Usage:
python eval_blimp.py --checkpoint checkpoints/ivme_base_ema.pt
python eval_blimp.py --checkpoint checkpoints/ivme_base_ema.pt --batch_size 256
"""
from __future__ import annotations
import argparse
import json
import sys
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from datasets import load_dataset
sys.path.insert(0, ".")
from model import IvmeConfig, IvmeConversate
# Tokenizer file handed to `Tokenizer.from_file`; a relative path, so it is
# resolved against the current working directory.
TOKENIZER_PATH = "ivme_tokenizer.json"

# BLiMP subtask identifiers, one per line. Each name is compared against the
# dataset's "UID" field to bucket sentence pairs, and the list order is the
# order in which subtasks are scored and reported.
BLIMP_TASKS = [
    "adjunct_island",
    "anaphor_gender_agreement",
    "anaphor_number_agreement",
    "animate_subject_passive",
    "animate_subject_trans",
    "causative",
    "complex_NP_island",
    "coordinate_structure_constraint_complex_left_branch",
    "coordinate_structure_constraint_object_extraction",
    "determiner_noun_agreement_1",
    "determiner_noun_agreement_2",
    "determiner_noun_agreement_irregular_1",
    "determiner_noun_agreement_irregular_2",
    "determiner_noun_agreement_with_adj_2",
    "determiner_noun_agreement_with_adj_irregular_1",
    "determiner_noun_agreement_with_adj_irregular_2",
    "determiner_noun_agreement_with_adjective_1",
    "distractor_agreement_relational_noun",
    "distractor_agreement_relative_clause",
    "drop_argument",
    "ellipsis_n_bar_1",
    "ellipsis_n_bar_2",
    "existential_there_object_raising",
    "existential_there_quantifiers_1",
    "existential_there_quantifiers_2",
    "existential_there_subject_raising",
    "expletive_it_object_raising",
    "inchoative",
    "intransitive",
    "irregular_past_participle_adjectives",
    "irregular_past_participle_verbs",
    "irregular_plural_subject_verb_agreement_1",
    "irregular_plural_subject_verb_agreement_2",
    "left_branch_island_echo_question",
    "left_branch_island_simple_question",
    "matrix_question_npi_licensor_present",
    "npi_present_1",
    "npi_present_2",
    "only_npi_licensor_present",
    "only_npi_scope",
    "passive_1",
    "passive_2",
    "principle_A_c_command",
    "principle_A_case_1",
    "principle_A_case_2",
    "principle_A_domain_1",
    "principle_A_domain_2",
    "principle_A_domain_3",
    "principle_A_reconstruction",
    "regular_plural_subject_verb_agreement_1",
    "regular_plural_subject_verb_agreement_2",
    "sentential_negation_npi_licensor_present",
    "sentential_negation_npi_scope",
    "sentential_subject_island",
    "superlative_quantifiers_1",
    "superlative_quantifiers_2",
    "tough_vs_raising_1",
    "tough_vs_raising_2",
    "transitive",
    "wh_island",
    "wh_questions_object_gap",
    "wh_questions_subject_gap",
    "wh_questions_subject_gap_long_distance",
    "wh_vs_that_no_gap",
    "wh_vs_that_no_gap_long_distance",
    "wh_vs_that_with_gap",
    "wh_vs_that_with_gap_long_distance",
]
@torch.no_grad()
def batch_logprobs(model, token_lists, device, pad_id, max_len):
    """Return the summed next-token log-probability of each sequence in a batch.

    Sequences are truncated to ``max_len`` tokens, right-padded with ``pad_id``
    to a common width and scored in one forward pass. For a sequence of ``n``
    tokens the result is the sum of ``log p(token[k] | model output at k-1)``
    for ``k = 1 .. n-1``; the first token has no preceding position and is
    therefore not scored. Padded positions contribute exactly zero.

    Args:
        model: Callable taking a ``(batch, width)`` long tensor and returning a
            pair whose first element is the logits, indexed as
            ``(batch, position, vocab)``.
            NOTE(review): right padding only leaves real positions unaffected
            if the model is causal -- confirm against the model definition.
        token_lists: ``list[list[int]]`` of token ids, one list per sequence.
        device: ``torch.device`` on which the model is evaluated.
        pad_id: Token id used to fill positions past the end of a sequence.
        max_len: Maximum number of tokens kept from each sequence.

    Returns:
        A float32 tensor of shape ``(len(token_lists),)`` on ``device``.
        Sequences with fewer than two tokens score ``0.0``; an empty
        ``token_lists`` yields an empty tensor.
    """
    if not token_lists:
        return torch.zeros(0, dtype=torch.float32, device=device)

    clipped = [ids[:max_len] for ids in token_lists]
    lengths = torch.tensor([len(ids) for ids in clipped], dtype=torch.long)
    width = int(lengths.max())

    # With fewer than two tokens there is nothing to predict, so skip the
    # forward pass entirely (a zero-width input may not be a valid model input).
    if width < 2:
        return torch.zeros(len(clipped), dtype=torch.float32, device=device)

    # Assemble the padded batch on the CPU and move it to the device once,
    # instead of issuing one small host-to-device copy per row.
    batch = torch.full((len(clipped), width), pad_id, dtype=torch.long)
    for row, ids in zip(batch, clipped):
        row[: len(ids)] = torch.tensor(ids, dtype=torch.long)
    inp = batch.to(device)

    with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                        enabled=device.type == "cuda"):
        logits, _ = model(inp)

    # Softmax in float32 so the summed scores are not limited by bf16 precision.
    logp = F.log_softmax(logits.float(), dim=-1)

    # Position k predicts token k + 1: drop the last prediction and the first
    # target so the two line up.
    targets = inp[:, 1:]
    tok_lp = logp[:, :-1, :].gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # A sequence of n tokens has n - 1 scored positions.
    positions = torch.arange(width - 1, device=device)
    keep = positions.unsqueeze(0) < (lengths.to(device) - 1).unsqueeze(1)

    # Zero the padding with a boolean fill rather than a multiply: a padded
    # position whose log-prob is -inf would otherwise produce -inf * 0 = NaN.
    return tok_lp.masked_fill(~keep, 0.0).sum(dim=1)
def main():
    """Score a checkpoint on every BLiMP subtask and save the accuracies.

    Command-line arguments:
        --checkpoint: path of a checkpoint holding a ``"cfg"`` object and a
            ``"model"`` state dict (required).
        --batch_size: number of sentence pairs scored per forward pass
            (default 256, must be positive).
        --output: path of the JSON file to write
            (default ``blimp_results.json``).

    For each subtask, a pair counts as correct when the total log-prob of the
    good sentence is strictly greater than that of the bad one. Per-task
    accuracies and the example-weighted average are printed and written to
    ``--output`` as ``{"tasks": {...}, "average": ...}``. A subtask for which
    the dataset has no rows is reported as skipped, stored as ``null`` and
    left out of the average.

    Raises:
        SystemExit: if ``--batch_size`` is not positive, or if no example at
            all could be scored.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--batch_size", type=int, default=256)
    ap.add_argument("--output", default="blimp_results.json")
    args = ap.parse_args()
    # A zero step crashes range(); a negative one yields no batches and would
    # silently report 0% everywhere.
    if args.batch_size < 1:
        ap.error("--batch_size must be a positive integer")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tok = Tokenizer.from_file(TOKENIZER_PATH)
    pad_id = tok.token_to_id("<|pad|>")
    if pad_id is None:
        # Tokenizer has no dedicated pad token; fall back to id 0. Padding is
        # masked out of the score, so the exact id only has to be a valid one.
        pad_id = 0

    # NOTE(security): weights_only=False unpickles arbitrary Python objects
    # (needed here because the checkpoint stores the config object). Only
    # load checkpoints from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    cfg = ckpt["cfg"]
    cfg.attn_backend = "sdpa"
    max_len = cfg.max_seq_len
    model = IvmeConversate(cfg).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    print(f"[blimp] model loaded: {model.num_params()/1e6:.1f}M on {device}")

    print("[blimp] loading full BLiMP dataset (one download)...")
    full_ds = load_dataset("WillHeld/blimp", split="train")

    # Bucket the sentence pairs by subtask; rows whose UID is not in
    # BLIMP_TASKS are ignored.
    by_task = {t: {"good": [], "bad": []} for t in BLIMP_TASKS}
    for row in full_ds:
        bucket = by_task.get(row["UID"])
        if bucket is not None:
            bucket["good"].append(row["sentence_good"])
            bucket["bad"].append(row["sentence_bad"])
    print(f"[blimp] {len(full_ds)} examples bucketed into {len(BLIMP_TASKS)} subtasks\n")

    results = {}
    total_correct = total_examples = 0
    for i, task in enumerate(BLIMP_TASKS):
        goods = by_task[task]["good"]
        bads = by_task[task]["bad"]
        if not goods:
            # Avoid dividing by zero when the dataset has no rows for a task.
            print(f"[{i+1:02d}/{len(BLIMP_TASKS)}] {task:<55} skipped (no examples found)")
            results[task] = None
            continue

        good_tok = [tok.encode(s).ids for s in goods]
        bad_tok = [tok.encode(s).ids for s in bads]

        correct = 0
        for start in range(0, len(good_tok), args.batch_size):
            gb = good_tok[start : start + args.batch_size]
            bb = bad_tok[start : start + args.batch_size]
            g_lp = batch_logprobs(model, gb, device, pad_id, max_len)
            b_lp = batch_logprobs(model, bb, device, pad_id, max_len)
            # Strict comparison: a tie is counted as a miss.
            correct += (g_lp > b_lp).sum().item()

        acc = correct / len(goods)
        results[task] = acc
        total_correct += correct
        total_examples += len(goods)
        running = total_correct / total_examples
        print(f"[{i+1:02d}/{len(BLIMP_TASKS)}] {task:<55} {acc*100:5.1f}% "
              f"(avg: {running*100:.2f}%)")

    if total_examples == 0:
        raise SystemExit("[blimp] no BLiMP examples matched any subtask; nothing to score")

    final = total_correct / total_examples
    print(f"\n{'='*60}")
    print(f" BLiMP average: {final*100:.2f}% (random baseline: 50%)")
    print(f"{'='*60}")

    with open(args.output, "w", encoding="utf-8") as f:
        json.dump({"tasks": results, "average": final}, f, indent=2)
    print(f"\n[blimp] saved -> {args.output}")


if __name__ == "__main__":
    main()