theapemachine committed on
Commit
4c1ba64
·
verified ·
1 Parent(s): c7bee4d

Add benchmark harness: scoring.py

Browse files
Files changed (1) hide show
  1. benchmark/scoring.py +168 -0
benchmark/scoring.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Scoring utilities for the Cortex benchmark harness.
3
+
4
+ Two evaluation modes:
5
+ 1. Log-likelihood scoring: For multiple-choice tasks (HellaSwag, ARC, PIQA, etc.)
6
+ Computes the average log-probability the model assigns to each continuation.
7
+
8
+ 2. Generation scoring: For free-form generation tasks (passkey retrieval, etc.)
9
+ Generates text and checks against expected patterns.
10
+ """
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from typing import List, Optional, Tuple, Dict
15
+ import re
16
+
17
+
18
@torch.no_grad()
def log_likelihood_score(
    model,
    tokenizer,
    context: str,
    continuations: List[str],
    device: str = "cuda",
) -> List[float]:
    """
    Compute the length-normalized log-likelihood of each continuation.

    For each continuation:
      1. Tokenize ``context + continuation`` as one string (no special tokens).
      2. Run a single forward pass to get next-token logits.
      3. Average the log-probabilities of the continuation tokens only.

    The continuation is taken to be every token after the first
    ``len(tokenizer.encode(context))`` tokens of the joint encoding.

    Edge cases:
      * Empty context: the first continuation token has nothing to be
        conditioned on. If the tokenizer exposes ``bos_token_id`` (or, failing
        that, ``eos_token_id``) that token is prepended as the conditioning
        prefix; otherwise the first continuation token is left unscored.
      * Over-long input: the sequence is truncated from the LEFT so that the
        continuation (the part being scored) is preserved and only the oldest
        context is dropped.
      * A continuation that contributes no scorable token gets ``-inf``.

    Args:
        model: Callable language model; ``model(input_ids).logits`` must have
            shape ``[1, seq_len, vocab_size]``. ``model.config
            .max_position_embeddings`` is used as the length limit when
            present (default 2048).
        tokenizer: Object with ``encode(text, add_special_tokens=False)``
            returning a list of token ids.
        context: The prompt/context string.
        continuations: Candidate continuations to score.
        device: Device on which the input tensor is created.

    Returns:
        One score per continuation, in order (higher = more preferred).
    """
    # Loop-invariant work: the context encoding and the model's length limit
    # do not depend on the continuation being scored.
    ctx_ids = list(tokenizer.encode(context, add_special_tokens=False))
    config = getattr(model, "config", None)
    max_len = getattr(config, "max_position_embeddings", None) or 2048

    # With an empty context there is no position whose logits predict the
    # first continuation token, so condition on a start token when one exists.
    prefix_ids: List[int] = []
    if not ctx_ids:
        start_id = getattr(tokenizer, "bos_token_id", None)
        if start_id is None:
            start_id = getattr(tokenizer, "eos_token_id", None)
        if start_id is not None:
            prefix_ids = [start_id]

    neg_inf = float("-inf")
    scores: List[float] = []

    for cont in continuations:
        full_ids = prefix_ids + list(
            tokenizer.encode(context + cont, add_special_tokens=False)
        )
        cont_start = len(prefix_ids) + len(ctx_ids)

        # Nothing was added by the continuation (empty string, or the
        # tokenizer merged it into the context tokens).
        if len(full_ids) - cont_start <= 0:
            scores.append(neg_inf)
            continue

        # Left-truncate: keep the tail so the continuation survives.
        overflow = len(full_ids) - max_len
        if overflow > 0:
            full_ids = full_ids[overflow:]
            cont_start -= overflow

        # logits[i] predicts token[i + 1], so position 0 can never be scored.
        score_start = max(cont_start, 1)
        if score_start >= len(full_ids):
            scores.append(neg_inf)
            continue

        input_ids = torch.tensor([full_ids], dtype=torch.long, device=device)
        logits = model(input_ids).logits  # [1, seq_len, vocab_size]

        # Tokens at [score_start, end) are predicted by the logits at
        # [score_start - 1, end - 1).
        shift_logits = logits[0, score_start - 1 : len(full_ids) - 1, :]
        shift_labels = input_ids[0, score_start:]

        log_probs = F.log_softmax(shift_logits, dim=-1)
        token_log_probs = log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1)

        # Normalize by the number of scored tokens (average log-prob per token).
        scores.append(token_log_probs.mean().item())

    return scores
90
+
91
+
92
@torch.no_grad()
def generate_and_check(
    model,
    tokenizer,
    prompt: str,
    expected: str,
    max_new_tokens: int = 64,
    device: str = "cuda",
    exact_match: bool = False,
) -> Tuple[bool, str]:
    """
    Greedily generate a completion for ``prompt`` and compare it to ``expected``.

    Both sides are compared case-insensitively after stripping surrounding
    whitespace. Note that in substring mode an empty ``expected`` matches any
    output, because the empty string is contained in every string.

    Args:
        model: The language model (must provide ``generate``)
        tokenizer: The tokenizer (callable, with ``decode``)
        prompt: The input prompt
        expected: The expected answer string
        max_new_tokens: Upper bound on the number of generated tokens
        device: Device the encoded prompt is moved to
        exact_match: If True, the whole output must equal ``expected``;
            otherwise ``expected`` only has to occur somewhere in the output

    Returns:
        (is_correct, generated_text) where ``generated_text`` is the stripped
        decoding of the newly generated tokens only
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)

    # Fall back to EOS when the tokenizer defines no dedicated pad token.
    pad_id = tokenizer.pad_token_id
    pad_id = tokenizer.eos_token_id if pad_id is None else pad_id

    generation = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=1.0,
        pad_token_id=pad_id,
    )

    # The output echoes the prompt; keep only what comes after it.
    prompt_len = encoded["input_ids"].shape[1]
    completion_ids = generation[0, prompt_len:]
    generated = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

    target = expected.strip().lower()
    if exact_match:
        return generated.strip().lower() == target, generated
    return target in generated.lower(), generated
142
+
143
+
144
def accuracy_from_loglikelihoods(
    scores_per_example: List[Tuple[List[float], int]],
) -> Dict[str, float]:
    """
    Turn per-choice log-likelihood scores into an accuracy summary.

    For every example the choice with the highest score is the prediction
    (the earliest index wins ties); it counts as correct when it equals the
    gold index.

    Args:
        scores_per_example: List of (scores_for_each_choice, correct_index)

    Returns:
        Dict with keys "accuracy" (0.0 when there are no examples),
        "correct" and "total"
    """
    n_examples = len(scores_per_example)

    hits = 0
    for choice_scores, gold_idx in scores_per_example:
        # Argmax over (index, score) pairs; max() keeps the first maximum.
        best_idx, _ = max(enumerate(choice_scores), key=lambda pair: pair[1])
        if best_idx == gold_idx:
            hits += 1

    accuracy = hits / n_examples if n_examples > 0 else 0.0
    return {"accuracy": accuracy, "correct": hits, "total": n_examples}