# DECADE / evaluate_qa.py
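"""Evaluate QA hypotheses against reference answers with an LLM judge.

Two judging modes are supported:
  * legacy               -- the original binary yes/no judge prompt.
  * strict/partial/both  -- atomic judging: a rubric of independently
    checkable requirements ("atoms") is generated and cached per question,
    each atom is labelled correct/partially_correct/missing/incorrect, and
    a strict label plus a weighted partial score are reported.

Rubrics can be pre-built with --build_rubric_only. Example invocation (the
judge model must be a key in model_zoo; the name below is illustrative):

    python evaluate_qa.py --ref_file refs.json --hyp_file hyps.jsonl --eval_model_name gpt-4o --eval_mode both
"""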
import argparse
import json
import os
import re
import time
import numpy as np
from openai import OpenAI
from tqdm import tqdm
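# Azure support is optional: if azure-identity / AzureOpenAI are missing,
# fall back to None and rely on --debug/--vllm/--tritonai/--nvidia at runtime.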
try:
from openai import AzureOpenAI
from azure.identity import (
AzureCliCredential,
ChainedTokenCredential,
ManagedIdentityCredential,
get_bearer_token_provider,
)
AZURE_OAUTH_SCOPE = os.environ.get("AZURE_OAUTH_SCOPE", "")
if AZURE_OAUTH_SCOPE:
credential = get_bearer_token_provider(
ChainedTokenCredential(
AzureCliCredential(),
ManagedIdentityCredential(),
),
AZURE_OAUTH_SCOPE,
)
else:
credential = None
except ImportError:
AzureOpenAI = None
credential = None
from model_zoo import model_zoo
# Azure OpenAI endpoint (set AZURE_OPENAI_ENDPOINT env var to your deployment URL).
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "")
# OpenAI-compatible LiteLLM proxy URL (set LITELLM_BASE_URL env var to your proxy).
TRITONAI_BASE_URL = os.environ.get("LITELLM_BASE_URL", "")
ATOMIC_PROMPT_VERSION = "atomic-v1"
LEGACY_PROMPT_VERSION = "binary-v0"
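# Per-atom credit used when computing the weighted partial score.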
ATOM_SCORES = {
"correct": 1.0,
"partially_correct": 0.5,
"missing": 0.0,
"incorrect": 0.0,
}
def _retryable_status(e) -> "int | None":
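    """Best-effort extraction of a retryable HTTP status from an exception.

    Checks `status_code`/`http_status` attributes (where 403 also counts),
    then the attached response, then falls back to substring matching on
    the error message. Returns 429/500/503 (or 403) when the error looks
    transient, else None.
    """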
status = getattr(e, "status_code", None) or getattr(e, "http_status", None)
if status in (429, 500, 503, 403):
return status
resp = getattr(e, "response", None)
if resp is not None and getattr(resp, "status_code", None) in (429, 500, 503):
return resp.status_code
msg = str(e).lower()
if "429" in msg or "rate limit" in msg:
return 429
if "500" in msg or "internal server error" in msg:
return 500
if "503" in msg or "api configuration unavailable" in msg:
return 503
return None
def parse_json_object(text):
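    """Parse a JSON object from LLM output, tolerating Markdown code fences
    and surrounding chatter (falls back to the outermost {...} span).

    Example:
        >>> parse_json_object('```json {"a": 1} ```')
        {'a': 1}
    """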
text = (text or "").strip()
if text.startswith("```"):
text = re.sub(r"^```(?:json)?", "", text).strip()
text = re.sub(r"```$", "", text).strip()
try:
return json.loads(text)
except json.JSONDecodeError:
start = text.find("{")
end = text.rfind("}") + 1
if start >= 0 and end > start:
return json.loads(text[start:end])
raise
def sanitize_model_name(name):
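    """Make a model name filesystem-safe by collapsing runs of characters
    outside [A-Za-z0-9_.-] into single underscores.

    Example:
        >>> sanitize_model_name("org/model:v1")
        'org_model_v1'
    """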
return re.sub(r"[^A-Za-z0-9_.-]+", "_", name)
def read_json_or_jsonl(path):
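    """Load a file as one JSON document, falling back to JSON Lines
    (one object per line) if whole-file parsing fails."""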
try:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except json.JSONDecodeError:
with open(path, "r", encoding="utf-8") as f:
return [json.loads(line) for line in f if line.strip()]
def read_existing_jsonl(path):
if not path or not os.path.exists(path):
return {}
rows = {}
with open(path, "r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
obj = json.loads(line)
if "question_id" in obj:
rows[obj["question_id"]] = obj
return rows
def write_json(path, obj):
tmp_path = path + ".tmp"
with open(tmp_path, "w", encoding="utf-8") as f:
json.dump(obj, f, ensure_ascii=False, indent=2)
f.write("\n")
os.replace(tmp_path, path)
def append_jsonl(path, row):
with open(path, "a", encoding="utf-8") as f:
print(json.dumps(row, ensure_ascii=False), file=f, flush=True)
def default_result_file(hyp_file, metric_model, eval_mode):
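    """Build the default result path next to the hypothesis file. Legacy mode
    keeps the unsanitized model name so older result filenames keep resolving;
    atomic modes embed the sanitized model tag, prompt version, and eval mode."""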
if eval_mode == "legacy":
return f"{hyp_file}.eval-results-{metric_model}"
model_tag = sanitize_model_name(metric_model)
return f"{hyp_file}.eval-results-{model_tag}-{ATOMIC_PROMPT_VERSION}-{eval_mode}"
def default_rubric_file(ref_file):
return f"{ref_file}.{ATOMIC_PROMPT_VERSION}.rubric.json"
def load_rubric_file(path, ref_file):
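    """Load a cached rubric file, accepting either the full wrapper object
    (with a "rubrics" key) or a bare qid->rubric mapping. Returns an empty
    wrapper when the file is missing."""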
if not path or not os.path.exists(path):
return {
"prompt_version": ATOMIC_PROMPT_VERSION,
"source_ref_file": ref_file,
"rubrics": {},
}
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
if "rubrics" in data:
data.setdefault("prompt_version", ATOMIC_PROMPT_VERSION)
data.setdefault("source_ref_file", ref_file)
return data
return {
"prompt_version": ATOMIC_PROMPT_VERSION,
"source_ref_file": ref_file,
"rubrics": data,
}
def question_type_guidance(task):
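    """Return judging guidance for a question type, shared by the rubric
    builder and the atomic judge prompts."""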
if task in ["Information Absence"]:
return (
"The correct answer is that the information is unavailable, absent, "
"not yet known, or not supported. A response that gives a concrete "
"answer instead of abstaining is incorrect."
)
if task in ["Aggregation", "single-session-user", "single-session-assistant", "multi-session"]:
return (
"Check every requested item, count, list member, and named fact. "
"Exact counts and required list coverage matter. Extra material "
"facts that change the answer should be flagged."
)
if task in ["Aggregation + Temporal"]:
return (
"Check both the aggregated facts and their time/order associations. "
"An answer can name the right item but still be wrong if the timing, "
"ordering, before/after relation, or year is wrong."
)
if task in ["Temporal Reasoning", "temporal-reasoning"]:
return (
"Check the specific time, date, year, sequence, duration, or temporal "
"relationship asked for. Accept +/-1 only for day/week/month duration "
"counts, not for years, event identity, or ordering."
)
if task in ["Knowledge Update", "knowledge-update"]:
return (
"Check the current or most recent state. Historical context is fine "
"only if the final/current state is clearly correct. Outdated states "
"presented as current are incorrect."
)
if task == "single-session-preference":
return (
"Check whether the response recalls and applies the stated preference. "
"Do not require unnecessary verbosity, but contradictions are incorrect."
)
return "Check whether the model response satisfies all required facts in the reference answer."
def get_anscheck_prompt(task, question, answer, response):
"""Legacy binary yes/no prompt kept for backward compatibility."""
if task in ["Information Absence"]:
template = """You are evaluating whether a model correctly identified that a question is unanswerable given the available personal chat history.
Question: {question}
Explanation of why it is unanswerable: {answer}
Model Response: {response}
Evaluation criteria:
- CORRECT if the model explicitly states that the information is not available, insufficient, or that the question cannot be answered from the provided context.
- INCORRECT if the model fabricates an answer or fails to acknowledge the unanswerable nature of the question.
- The model does not need to use the exact word "unanswerable" -- expressing uncertainty or lack of information is sufficient.
Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
elif task in ["Aggregation", "single-session-user", "single-session-assistant", "multi-session"]:
template = """You are evaluating whether a model correctly answered a question that requires aggregating specific facts from a user's personal chat history.
Question: {question}
Reference Answer: {answer}
Model Response: {response}
Evaluation criteria:
- CORRECT if the response identifies all key items, facts, or counts present in the reference answer, even if phrased differently or with added context.
- INCORRECT if the response:
- States a wrong count (e.g., says "5" when the answer is "3")
- Omits one or more key items/facts listed in the reference
- Lists mostly wrong items even if the count is right
- Partial answers that cover only a subset of required items are INCORRECT.
- Verbose responses are acceptable as long as all reference items are present within them.
- If the response contains correct items but also lists additional plausible-sounding but unverified items beyond the reference, this does NOT make it incorrect -- evaluate only whether the reference items are covered.
- Semantic equivalence counts as correct (e.g., "RSI" = "Repetitive Strain Injury").
Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
elif task in ["Aggregation + Temporal"]:
template = """You are evaluating whether a model correctly answered a question that requires both aggregating facts AND reasoning about their temporal order or time associations from a user's personal chat history.
Question: {question}
Reference Answer: {answer}
Model Response: {response}
Evaluation criteria:
- CORRECT if the response captures both:
(a) all key events/facts listed in the reference, and
(b) their correct temporal associations (ordering, time periods, or "before/after" relationships).
- INCORRECT if the response:
- Omits one or more key events or facts from the reference
- Gets the temporal ordering or time associations wrong
- Captures only the content without the temporal aspects, or vice versa
- Responses that describe the correct progression/sequence in different words are acceptable.
- Partial answers covering only some events or ignoring time aspects are INCORRECT.
- Minor wording differences or additional explanatory context are acceptable.
Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
elif task in ["Temporal Reasoning", "temporal-reasoning"]:
template = """You are evaluating whether a model correctly answered a question about temporal relationships in a user's personal chat history.
Question: {question}
Reference Answer: {answer}
Model Response: {response}
Evaluation criteria:
- CORRECT if the response correctly identifies the specific time, date, year, sequence, or temporal relationship asked about.
- INCORRECT if the response states a wrong year, wrong sequence, wrong temporal relationship, or misidentifies which event came first/last.
- Off-by-one tolerance: if the question asks for a count of days, weeks, or months, accept answers that differ by +/-1. This tolerance does NOT apply to years or to identifying specific events/artifacts.
- Responses that correctly identify the fact but with verbose context are acceptable.
- If the response hedges but still states the correct answer, it is correct.
Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
elif task in ["Knowledge Update", "knowledge-update"]:
template = """You are evaluating whether a model correctly answered a question about the most recent or current state of something that changed over time in a user's personal chat history.
Question: {question}
Reference Answer: {answer}
Model Response: {response}
Evaluation criteria:
- CORRECT if the response correctly identifies the most recent/current state as described in the reference answer.
- The response may include earlier historical states as context -- this is acceptable as long as the current/final state is correctly identified and clearly stated.
- INCORRECT if the response:
- States an outdated/superseded state as the current one
- Omits the current state entirely
- Correctly describes history but draws the wrong conclusion about what the current state is
- Semantic equivalence counts (e.g., "flexitarian" and "semi-vegetarian diet with occasional meat" are equivalent if contextually clear).
Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
elif task == "single-session-preference":
template = """You are evaluating whether a model correctly answered a personalized question based on a user's stated preferences from their chat history.
Question: {question}
Reference Rubric: {answer}
Model Response: {response}
Evaluation criteria:
- CORRECT if the response recalls and applies the user's personal preferences correctly, even if not covering every point in the rubric.
- INCORRECT if the response ignores, contradicts, or misremembers the user's preferences.
Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
else:
template = """You are evaluating whether a model's response correctly answers a question based on a user's personal chat history.
Question: {question}
Reference Answer: {answer}
Model Response: {response}
Is the response correct? It is correct if it contains all key information from the reference answer, even if phrased differently.
Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
return template.format(question=question, answer=answer, response=response)
def build_rubric_prompt(task, question, answer):
return f"""You are creating an atomic grading rubric for an open-ended QA benchmark.
Question type: {task}
Question:
{question}
Reference answer:
{answer}
Question-type guidance:
{question_type_guidance(task)}
Decompose the reference answer into the smallest independently checkable requirements needed to answer the question.
Rules:
- Each atom should be a single required answer unit: an entity, count, date/year, order relation, current-state conclusion, or abstention requirement.
- If an entity and its temporal relation are inseparable for correctness, keep them in the same atom.
- For list/count questions, include an atom for the exact count when the question asks "how many", and atoms for each required listed item when the item identities matter.
- For Information Absence, usually use one atom requiring the response to clearly state that the information is unavailable/insufficient/not discussed, and add a strict note that concrete fabricated answers are wrong.
- Do not include supporting evidence requirements or session IDs unless the question explicitly asks for them.
- Weights should normally be 1.0. Use a higher weight only when one atom is clearly the main answer and other atoms are minor.
Return JSON only with this schema:
{{
"required_atoms": [
{{
"id": "a1",
"requirement": "short, specific grading requirement",
"weight": 1.0
}}
],
"strict_notes": ["short note about exactness, ordering, abstention, or hallucination handling"]
}}"""
def build_atomic_eval_prompt(task, question, answer, response, rubric):
rubric_str = json.dumps(
{
"required_atoms": rubric["required_atoms"],
"strict_notes": rubric.get("strict_notes", []),
},
ensure_ascii=False,
indent=2,
)
return f"""You are an LLM-as-a-judge evaluating one model response against a reference answer.
Question type: {task}
Question:
{question}
Reference answer:
{answer}
Model response:
{response}
Atomic grading rubric:
{rubric_str}
Question-type guidance:
{question_type_guidance(task)}
Judge each atom independently.
Atom labels:
- correct: the response fully satisfies this atom, allowing semantic paraphrase.
- partially_correct: the response gets the main idea but is incomplete or slightly underspecified. Use this sparingly. Do not use it for wrong counts, wrong years, wrong named entities, wrong ordering, or a concrete answer to an Information Absence question.
- missing: the response does not address this atom.
- incorrect: the response contradicts this atom or gives the wrong count/entity/date/order/current state.
Also identify unsupported_or_contradictory material:
- severity "material": extra answer content that changes the final answer, adds extra items to an exact list/count, gives an outdated current state, fabricates a concrete answer for Information Absence, or contradicts any atom.
- severity "minor": harmless context or extra explanation that does not change the answer.
Return JSON only with this schema:
{{
"atom_judgments": [
{{
"id": "a1",
"label": "correct|partially_correct|missing|incorrect",
"rationale": "brief reason"
}}
],
"unsupported_or_contradictory": [
{{
"text": "extra or contradictory claim",
"severity": "minor|material",
"rationale": "brief reason"
}}
],
"absence_mismatch": false,
"overall_rationale": "one or two sentence summary"
}}"""
def llm_call(
deployment_name: str,
api_version: str,
_prompt: str,
debug: bool = False,
vllm: bool = False,
tritonai: bool = False,
nvidia: bool = False,
):
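    """Send a single-turn prompt to the selected backend and return an
    OpenAI-style completion object.

    Routing: nvidia -> NVIDIA inference API; tritonai -> LiteLLM proxy;
    "claude-*" deployments -> Anthropic (adapted to the OpenAI response
    shape); vllm -> local vLLM server; debug -> OpenAI; otherwise Azure
    OpenAI. Retryable HTTP errors (429/500/503/403) trigger a sleep and
    retry; anything else is re-raised.
    """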
if nvidia:
client = OpenAI(
api_key=os.getenv("NV_API_KEY"),
base_url="https://inference-api.nvidia.com",
)
while True:
try:
return client.chat.completions.create(
model=deployment_name,
messages=[{"role": "user", "content": _prompt}],
)
except Exception as e:
st = _retryable_status(e)
if st in (429, 500, 503, 403):
print(f"[WARN] HTTP {st} from NVIDIA API; sleeping 60s then retrying...", flush=True)
time.sleep(60)
continue
print("One exception captured", repr(e), flush=True)
raise
if tritonai:
client = OpenAI(
api_key=os.getenv("TRITONAI_API_KEY"),
base_url=TRITONAI_BASE_URL,
)
while True:
try:
return client.chat.completions.create(
model=deployment_name,
messages=[{"role": "user", "content": _prompt}],
)
except Exception as e:
st = _retryable_status(e)
if st in (429, 500, 503, 403):
print(f"[WARN] HTTP {st} from LiteLLM proxy; sleeping 60s then retrying...", flush=True)
time.sleep(60)
continue
print("One exception captured", repr(e), flush=True)
raise
if deployment_name.startswith("claude-"):
import anthropic
        client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

        # Minimal adapters so the Anthropic response exposes the OpenAI
        # chat-completions shape (completion.choices[0].message.content),
        # keeping callers uniform across backends.
        class _Msg:
            def __init__(self, text):
                self.content = text

        class _Choice:
            def __init__(self, text):
                self.message = _Msg(text)

        class _Completion:
            def __init__(self, text):
                self.choices = [_Choice(text)]

        while True:
            try:
                msg = client.messages.create(
                    model=deployment_name,
                    max_tokens=1024,
                    messages=[{"role": "user", "content": _prompt}],
                )
                return _Completion(msg.content[0].text)
except Exception as e:
st = _retryable_status(e)
if st in (429, 500, 503, 403):
print(f"[WARN] HTTP {st} from Anthropic; sleeping 60s then retrying...", flush=True)
time.sleep(60)
continue
print("One exception captured", repr(e), flush=True)
raise
if vllm:
client = OpenAI(
base_url=os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
api_key=os.getenv("VLLM_API_KEY", "EMPTY"),
)
elif debug:
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    else:
        if AzureOpenAI is None:
            raise RuntimeError(
                "Azure OpenAI dependencies are not installed; "
                "use --debug, --vllm, --tritonai, or --nvidia instead."
            )
        client = AzureOpenAI(
            azure_endpoint=endpoint,
            azure_ad_token_provider=credential,
            api_version=api_version,
        )
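    # Note: this default path sends the prompt as a "system" message, while
    # the proxy/NVIDIA/Anthropic paths above use a "user" message; kept as-is
    # to preserve the original evaluation behavior.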
kwargs = {
"model": deployment_name,
"messages": [{"role": "system", "content": _prompt}],
}
while True:
try:
return client.chat.completions.create(**kwargs)
except Exception as e:
st = _retryable_status(e)
if st in (429, 500, 503, 403):
print(f"[WARN] HTTP {st} from LLM; sleeping 120s then retrying...", flush=True)
time.sleep(120)
continue
print("One exception captured", repr(e), flush=True)
raise
def call_json_llm(prompt, deployment_name, api_version, args, max_retries=3):
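    """Query the judge and parse its reply as a JSON object, retrying the
    full call (with a short sleep) up to max_retries times when the reply
    is not parseable JSON. Returns (parsed_object, raw_text)."""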
last_error = None
for attempt in range(max_retries):
completion = llm_call(
deployment_name,
api_version,
prompt,
debug=args.debug,
vllm=args.vllm,
tritonai=args.tritonai,
nvidia=args.nvidia,
)
        content = (completion.choices[0].message.content or "").strip()
try:
return parse_json_object(content), content
except Exception as e:
last_error = e
if attempt < max_retries - 1:
print(f"[WARN] Failed to parse judge JSON; retrying ({attempt + 1}/{max_retries})", flush=True)
time.sleep(2)
raise ValueError(f"Failed to parse JSON response from judge: {last_error}")
def fallback_rubric(qid, task, question, answer):
return {
"question_id": qid,
"question_type": task,
"question": question,
"reference_answer": answer,
"required_atoms": [
{
"id": "a1",
"requirement": f"Response must correctly answer the question according to the reference answer: {answer}",
"weight": 1.0,
}
],
"strict_notes": ["Fallback single-atom rubric produced because the generated rubric was invalid."],
"prompt_version": ATOMIC_PROMPT_VERSION,
}
def normalize_rubric(qid, task, question, answer, parsed):
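    """Validate a generated rubric: drop malformed atoms, default missing ids
    and non-positive weights to sane values, and coerce strict_notes to a list
    of strings. Falls back to a single-atom rubric if no valid atoms survive."""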
atoms = parsed.get("required_atoms", []) if isinstance(parsed, dict) else []
norm_atoms = []
for idx, atom in enumerate(atoms, start=1):
if not isinstance(atom, dict):
continue
requirement = str(atom.get("requirement", "")).strip()
if not requirement:
continue
atom_id = str(atom.get("id", f"a{idx}")).strip() or f"a{idx}"
try:
weight = float(atom.get("weight", 1.0))
except (TypeError, ValueError):
weight = 1.0
if weight <= 0:
weight = 1.0
norm_atoms.append({"id": atom_id, "requirement": requirement, "weight": weight})
if not norm_atoms:
return fallback_rubric(qid, task, question, answer)
strict_notes = parsed.get("strict_notes", [])
if not isinstance(strict_notes, list):
strict_notes = [str(strict_notes)]
return {
"question_id": qid,
"question_type": task,
"question": question,
"reference_answer": answer,
"required_atoms": norm_atoms,
"strict_notes": [str(x) for x in strict_notes],
"prompt_version": ATOMIC_PROMPT_VERSION,
}
def get_or_build_rubric(qdata, rubric_data, rubric_file, deployment_name, api_version, args):
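    """Return the cached rubric for a question, rebuilding it when it is
    absent, empty, produced by an older prompt version, or when
    --force_rebuild_rubric is set. The rubric file is rewritten after each
    newly built rubric so progress survives interruption."""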
qid = qdata["question_id"]
existing = rubric_data["rubrics"].get(qid)
if (
existing
and existing.get("prompt_version") == ATOMIC_PROMPT_VERSION
and existing.get("required_atoms")
and not args.force_rebuild_rubric
):
return existing
task = qdata["question_type"]
prompt = build_rubric_prompt(task, qdata["question"], qdata["answer"])
try:
parsed, raw = call_json_llm(prompt, deployment_name, api_version, args)
rubric = normalize_rubric(qid, task, qdata["question"], qdata["answer"], parsed)
rubric["rubric_raw_response"] = raw
except Exception as e:
print(f"[WARN] Falling back to single-atom rubric for {qid}: {e}", flush=True)
rubric = fallback_rubric(qid, task, qdata["question"], qdata["answer"])
rubric_data["rubrics"][qid] = rubric
write_json(rubric_file, rubric_data)
return rubric
def compute_atomic_scores(rubric, parsed):
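    """Aggregate per-atom judgments into the two reported metrics.

    strict_label: True only if every atom (at least one) is "correct", with
    no material unsupported/contradictory content and no absence mismatch.
    partial_score: weight-normalized mean of ATOM_SCORES over the rubric's
    atoms (unjudged atoms count as "missing"), zeroed on an absence mismatch
    and capped at 0.8 when material extra content is present.
    """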
atoms = rubric.get("required_atoms", [])
judgments_by_id = {}
raw_judgments = parsed.get("atom_judgments", []) if isinstance(parsed, dict) else []
if isinstance(raw_judgments, list):
for judgment in raw_judgments:
if not isinstance(judgment, dict):
continue
atom_id = str(judgment.get("id", "")).strip()
label = str(judgment.get("label", "")).strip()
if label not in ATOM_SCORES:
label = "incorrect"
judgments_by_id[atom_id] = {
"id": atom_id,
"label": label,
"score": ATOM_SCORES[label],
"rationale": str(judgment.get("rationale", "")),
}
norm_judgments = []
weighted_score = 0.0
total_weight = 0.0
for atom in atoms:
atom_id = atom["id"]
weight = float(atom.get("weight", 1.0))
judgment = judgments_by_id.get(
atom_id,
{"id": atom_id, "label": "missing", "score": 0.0, "rationale": "No judgment returned."},
)
judgment["requirement"] = atom["requirement"]
judgment["weight"] = weight
norm_judgments.append(judgment)
weighted_score += judgment["score"] * weight
total_weight += weight
extras = parsed.get("unsupported_or_contradictory", []) if isinstance(parsed, dict) else []
if not isinstance(extras, list):
extras = []
material_extras = [
x for x in extras
if isinstance(x, dict) and str(x.get("severity", "")).strip() == "material"
]
absence_mismatch = bool(parsed.get("absence_mismatch", False)) if isinstance(parsed, dict) else False
strict_label = (
bool(norm_judgments)
and all(j["label"] == "correct" for j in norm_judgments)
and not material_extras
and not absence_mismatch
)
partial_score = weighted_score / total_weight if total_weight > 0 else 0.0
if absence_mismatch:
partial_score = 0.0
elif material_extras and partial_score > 0.8:
partial_score = 0.8
return {
"strict_label": strict_label,
"partial_score": round(partial_score, 4),
"atom_judgments": norm_judgments,
"unsupported_or_contradictory": extras,
"absence_mismatch": absence_mismatch,
"overall_rationale": str(parsed.get("overall_rationale", "")) if isinstance(parsed, dict) else "",
}
def judge_atomic(qdata, hypothesis, rubric, deployment_name, api_version, args):
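    """Run one atomic judging call for a hypothesis and return the
    autoeval_atomic record: strict label, partial score, per-atom
    judgments, flagged extras, and the judge's raw response."""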
prompt = build_atomic_eval_prompt(
qdata["question_type"],
qdata["question"],
qdata["answer"],
hypothesis,
rubric,
)
parsed, raw = call_json_llm(prompt, deployment_name, api_version, args)
scores = compute_atomic_scores(rubric, parsed)
return {
"model": args.eval_model_name,
"prompt_version": ATOMIC_PROMPT_VERSION,
"eval_mode": args.eval_mode,
"strict_label": scores["strict_label"],
"partial_score": scores["partial_score"],
"required_atoms": rubric["required_atoms"],
"atom_judgments": scores["atom_judgments"],
"unsupported_or_contradictory": scores["unsupported_or_contradictory"],
"absence_mismatch": scores["absence_mismatch"],
"overall_rationale": scores["overall_rationale"],
"raw_response": raw,
}
def should_skip_existing(existing_row, eval_mode):
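    """Return True when a previously written result row already covers the
    requested eval mode (and, for atomic modes, the current prompt version)."""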
if eval_mode == "legacy":
return "autoeval_label" in existing_row
atomic = existing_row.get("autoeval_atomic")
return bool(atomic and atomic.get("prompt_version") == ATOMIC_PROMPT_VERSION)
def safe_mean(values):
if not values:
return float("nan")
return float(np.mean(values))
def print_legacy_summary(logs, qtype2acc):
labels = [1 if x["autoeval_label"]["label"] else 0 for x in logs if "autoeval_label" in x]
print("Accuracy:", round(safe_mean(labels), 4))
for k, v in sorted(qtype2acc.items()):
print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v)))
def print_atomic_summary(logs, qtype2strict, qtype2partial, eval_mode):
strict_values = [
1 if x["autoeval_atomic"]["strict_label"] else 0
for x in logs
if "autoeval_atomic" in x
]
partial_values = [
float(x["autoeval_atomic"]["partial_score"])
for x in logs
if "autoeval_atomic" in x
]
if eval_mode in ("strict", "both"):
print("Strict Accuracy:", round(safe_mean(strict_values), 4))
for k, v in sorted(qtype2strict.items()):
print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v)))
if eval_mode in ("partial", "both"):
print("Partial Score:", round(safe_mean(partial_values), 4))
for k, v in sorted(qtype2partial.items()):
print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v)))
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--hyp_file", type=str, default=None)
parser.add_argument("--ref_file", type=str, required=True)
parser.add_argument("--eval_model_name", type=str, required=True)
parser.add_argument(
"--eval_mode",
type=str,
default="both",
choices=["legacy", "strict", "partial", "both"],
help="legacy uses the old yes/no judge; strict/partial/both use atomic JSON judging.",
)
parser.add_argument("--rubric_file", type=str, default=None)
parser.add_argument("--build_rubric_only", action="store_true", default=False)
parser.add_argument("--force_rebuild_rubric", action="store_true", default=False)
parser.add_argument("--result_file", type=str, default=None)
parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--vllm", action="store_true", default=False)
parser.add_argument("--tritonai", action="store_true", default=False,
help="Use OpenAI-compatible LiteLLM proxy (set TRITONAI_API_KEY env var)")
parser.add_argument("--nvidia", action="store_true", default=False,
help="Use NVIDIA inference API (set NV_API_KEY env var)")
parser.add_argument("--verbose", action=argparse.BooleanOptionalAction, default=True)
args = parser.parse_args()
if not args.build_rubric_only and not args.hyp_file:
parser.error("--hyp_file is required unless --build_rubric_only is set")
metric_model = args.eval_model_name
deployment_name, api_version = model_zoo[metric_model]
references = read_json_or_jsonl(args.ref_file)
qid2qdata = {entry["question_id"]: entry for entry in references}
qid2qtype = {entry["question_id"]: entry["question_type"] for entry in references}
qtypes = set(qid2qtype.values())
rubric_data = None
rubric_file = args.rubric_file or default_rubric_file(args.ref_file)
if args.eval_mode != "legacy" or args.build_rubric_only:
rubric_data = load_rubric_file(rubric_file, args.ref_file)
if args.build_rubric_only:
for entry in tqdm(references, desc="building rubrics"):
get_or_build_rubric(entry, rubric_data, rubric_file, deployment_name, api_version, args)
print(f"Saved rubric file to {rubric_file}")
return
result_file = args.result_file or default_result_file(args.hyp_file, metric_model, args.eval_mode)
existing = read_existing_jsonl(result_file)
hypotheses = read_json_or_jsonl(args.hyp_file)
qtype2acc = {t: [] for t in qtypes}
qtype2strict = {t: [] for t in qtypes}
qtype2partial = {t: [] for t in qtypes}
logs = []
for entry in tqdm(hypotheses):
qid = entry.get("question_id")
if qid not in qid2qtype:
if qid is not None:
print(f"Warning: skipping {qid} as it is not in reference data.")
continue
if qid in existing and should_skip_existing(existing[qid], args.eval_mode):
existing_row = existing[qid]
logs.append(existing_row)
qtype = qid2qtype[qid]
if args.eval_mode == "legacy":
label = existing_row["autoeval_label"]["label"]
qtype2acc[qtype].append(1 if label else 0)
else:
atomic = existing_row["autoeval_atomic"]
qtype2strict[qtype].append(1 if atomic["strict_label"] else 0)
qtype2partial[qtype].append(float(atomic["partial_score"]))
continue
qdata = qid2qdata[qid]
qtype = qdata["question_type"]
hyp = entry["hypothesis"]
if args.eval_mode == "legacy":
prompt = get_anscheck_prompt(qtype, qdata["question"], qdata["answer"], hyp)
completion = llm_call(
deployment_name,
api_version,
prompt,
debug=args.debug,
vllm=args.vllm,
tritonai=args.tritonai,
nvidia=args.nvidia,
)
eval_response = completion.choices[0].message.content.strip()
            last_line = next((ln.strip().lower() for ln in reversed(eval_response.splitlines()) if ln.strip()), "")
            label = last_line.startswith("yes")
row = dict(entry)
row["autoeval_label"] = {
"model": metric_model,
"prompt_version": LEGACY_PROMPT_VERSION,
"label": label,
"raw_response": eval_response,
}
logs.append(row)
qtype2acc[qtype].append(1 if label else 0)
if args.verbose:
print(json.dumps({
"question": qdata["question"],
"answer": qdata["answer"],
"hypothesis": hyp,
"autoeval_label": label,
}, indent=4), flush=True)
append_jsonl(result_file, row)
continue
rubric = get_or_build_rubric(qdata, rubric_data, rubric_file, deployment_name, api_version, args)
try:
atomic_eval = judge_atomic(qdata, hyp, rubric, deployment_name, api_version, args)
except ValueError as _judge_err:
print(f"[WARN] judge_atomic failed for {qdata['question_id']}, writing zero score: {_judge_err}", flush=True)
atoms = rubric.get("required_atoms", [])
atomic_eval = {
"model": deployment_name,
"prompt_version": ATOMIC_PROMPT_VERSION,
"eval_mode": args.eval_mode,
"strict_label": False,
"partial_score": 0.0,
"required_atoms": atoms,
"atom_judgments": [{"id": a["id"], "label": "error", "score": 0.0, "rationale": "judge parse error", "requirement": a.get("requirement", ""), "weight": a.get("weight", 1.0)} for a in atoms],
"unsupported_or_contradictory": [],
"absence_mismatch": False,
"overall_rationale": f"Skipped: judge JSON parse error ({_judge_err})",
}
row = dict(entry)
row["autoeval_atomic"] = atomic_eval
logs.append(row)
qtype2strict[qtype].append(1 if atomic_eval["strict_label"] else 0)
qtype2partial[qtype].append(float(atomic_eval["partial_score"]))
if args.verbose:
print(json.dumps({
"question": qdata["question"],
"answer": qdata["answer"],
"hypothesis": hyp,
"strict_label": atomic_eval["strict_label"],
"partial_score": atomic_eval["partial_score"],
"atom_judgments": atomic_eval["atom_judgments"],
}, indent=4), flush=True)
append_jsonl(result_file, row)
if args.eval_mode == "legacy":
print_legacy_summary(logs, qtype2acc)
else:
print_atomic_summary(logs, qtype2strict, qtype2partial, args.eval_mode)
print(f"Rubric file: {rubric_file}")
print("Saved to", result_file)
if __name__ == "__main__":
main()