import argparse import json import os import re import time import numpy as np from openai import OpenAI from tqdm import tqdm try: from openai import AzureOpenAI from azure.identity import ( AzureCliCredential, ChainedTokenCredential, ManagedIdentityCredential, get_bearer_token_provider, ) AZURE_OAUTH_SCOPE = os.environ.get("AZURE_OAUTH_SCOPE", "") if AZURE_OAUTH_SCOPE: credential = get_bearer_token_provider( ChainedTokenCredential( AzureCliCredential(), ManagedIdentityCredential(), ), AZURE_OAUTH_SCOPE, ) else: credential = None except ImportError: AzureOpenAI = None credential = None from model_zoo import model_zoo # Azure OpenAI endpoint (set AZURE_OPENAI_ENDPOINT env var to your deployment URL). endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "") # OpenAI-compatible LiteLLM proxy URL (set LITELLM_BASE_URL env var to your proxy). TRITONAI_BASE_URL = os.environ.get("LITELLM_BASE_URL", "") ATOMIC_PROMPT_VERSION = "atomic-v1" LEGACY_PROMPT_VERSION = "binary-v0" ATOM_SCORES = { "correct": 1.0, "partially_correct": 0.5, "missing": 0.0, "incorrect": 0.0, } def _retryable_status(e) -> "int | None": status = getattr(e, "status_code", None) or getattr(e, "http_status", None) if status in (429, 500, 503, 403): return status resp = getattr(e, "response", None) if resp is not None and getattr(resp, "status_code", None) in (429, 500, 503): return resp.status_code msg = str(e).lower() if "429" in msg or "rate limit" in msg: return 429 if "500" in msg or "internal server error" in msg: return 500 if "503" in msg or "api configuration unavailable" in msg: return 503 return None def parse_json_object(text): text = (text or "").strip() if text.startswith("```"): text = re.sub(r"^```(?:json)?", "", text).strip() text = re.sub(r"```$", "", text).strip() try: return json.loads(text) except json.JSONDecodeError: start = text.find("{") end = text.rfind("}") + 1 if start >= 0 and end > start: return json.loads(text[start:end]) raise def sanitize_model_name(name): return re.sub(r"[^A-Za-z0-9_.-]+", "_", name) def read_json_or_jsonl(path): try: with open(path, "r", encoding="utf-8") as f: return json.load(f) except json.JSONDecodeError: with open(path, "r", encoding="utf-8") as f: return [json.loads(line) for line in f if line.strip()] def read_existing_jsonl(path): if not path or not os.path.exists(path): return {} rows = {} with open(path, "r", encoding="utf-8") as f: for line in f: if not line.strip(): continue obj = json.loads(line) if "question_id" in obj: rows[obj["question_id"]] = obj return rows def write_json(path, obj): tmp_path = path + ".tmp" with open(tmp_path, "w", encoding="utf-8") as f: json.dump(obj, f, ensure_ascii=False, indent=2) f.write("\n") os.replace(tmp_path, path) def append_jsonl(path, row): with open(path, "a", encoding="utf-8") as f: print(json.dumps(row, ensure_ascii=False), file=f, flush=True) def default_result_file(hyp_file, metric_model, eval_mode): if eval_mode == "legacy": return f"{hyp_file}.eval-results-{metric_model}" model_tag = sanitize_model_name(metric_model) return f"{hyp_file}.eval-results-{model_tag}-{ATOMIC_PROMPT_VERSION}-{eval_mode}" def default_rubric_file(ref_file): return f"{ref_file}.{ATOMIC_PROMPT_VERSION}.rubric.json" def load_rubric_file(path, ref_file): if not path or not os.path.exists(path): return { "prompt_version": ATOMIC_PROMPT_VERSION, "source_ref_file": ref_file, "rubrics": {}, } with open(path, "r", encoding="utf-8") as f: data = json.load(f) if "rubrics" in data: data.setdefault("prompt_version", ATOMIC_PROMPT_VERSION) data.setdefault("source_ref_file", ref_file) return data return { "prompt_version": ATOMIC_PROMPT_VERSION, "source_ref_file": ref_file, "rubrics": data, } def question_type_guidance(task): if task in ["Information Absence"]: return ( "The correct answer is that the information is unavailable, absent, " "not yet known, or not supported. A response that gives a concrete " "answer instead of abstaining is incorrect." ) if task in ["Aggregation", "single-session-user", "single-session-assistant", "multi-session"]: return ( "Check every requested item, count, list member, and named fact. " "Exact counts and required list coverage matter. Extra material " "facts that change the answer should be flagged." ) if task in ["Aggregation + Temporal"]: return ( "Check both the aggregated facts and their time/order associations. " "An answer can name the right item but still be wrong if the timing, " "ordering, before/after relation, or year is wrong." ) if task in ["Temporal Reasoning", "temporal-reasoning"]: return ( "Check the specific time, date, year, sequence, duration, or temporal " "relationship asked for. Accept +/-1 only for day/week/month duration " "counts, not for years, event identity, or ordering." ) if task in ["Knowledge Update", "knowledge-update"]: return ( "Check the current or most recent state. Historical context is fine " "only if the final/current state is clearly correct. Outdated states " "presented as current are incorrect." ) if task == "single-session-preference": return ( "Check whether the response recalls and applies the stated preference. " "Do not require unnecessary verbosity, but contradictions are incorrect." ) return "Check whether the model response satisfies all required facts in the reference answer." def get_anscheck_prompt(task, question, answer, response): """Legacy binary yes/no prompt kept for backward compatibility.""" if task in ["Information Absence"]: template = """You are evaluating whether a model correctly identified that a question is unanswerable given the available personal chat history. Question: {question} Explanation of why it is unanswerable: {answer} Model Response: {response} Evaluation criteria: - CORRECT if the model explicitly states that the information is not available, insufficient, or that the question cannot be answered from the provided context. - INCORRECT if the model fabricates an answer or fails to acknowledge the unanswerable nature of the question. - The model does not need to use the exact word "unanswerable" -- expressing uncertainty or lack of information is sufficient. Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no""" elif task in ["Aggregation", "single-session-user", "single-session-assistant", "multi-session"]: template = """You are evaluating whether a model correctly answered a question that requires aggregating specific facts from a user's personal chat history. Question: {question} Reference Answer: {answer} Model Response: {response} Evaluation criteria: - CORRECT if the response identifies all key items, facts, or counts present in the reference answer, even if phrased differently or with added context. - INCORRECT if the response: - States a wrong count (e.g., says "5" when the answer is "3") - Omits one or more key items/facts listed in the reference - Lists mostly wrong items even if the count is right - Partial answers that cover only a subset of required items are INCORRECT. - Verbose responses are acceptable as long as all reference items are present within them. - If the response contains correct items but also lists additional plausible-sounding but unverified items beyond the reference, this does NOT make it incorrect -- evaluate only whether the reference items are covered. - Semantic equivalence counts as correct (e.g., "RSI" = "Repetitive Strain Injury"). Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no""" elif task in ["Aggregation + Temporal"]: template = """You are evaluating whether a model correctly answered a question that requires both aggregating facts AND reasoning about their temporal order or time associations from a user's personal chat history. Question: {question} Reference Answer: {answer} Model Response: {response} Evaluation criteria: - CORRECT if the response captures both: (a) all key events/facts listed in the reference, and (b) their correct temporal associations (ordering, time periods, or "before/after" relationships). - INCORRECT if the response: - Omits one or more key events or facts from the reference - Gets the temporal ordering or time associations wrong - Captures only the content without the temporal aspects, or vice versa - Responses that describe the correct progression/sequence in different words are acceptable. - Partial answers covering only some events or ignoring time aspects are INCORRECT. - Minor wording differences or additional explanatory context are acceptable. Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no""" elif task in ["Temporal Reasoning", "temporal-reasoning"]: template = """You are evaluating whether a model correctly answered a question about temporal relationships in a user's personal chat history. Question: {question} Reference Answer: {answer} Model Response: {response} Evaluation criteria: - CORRECT if the response correctly identifies the specific time, date, year, sequence, or temporal relationship asked about. - INCORRECT if the response states a wrong year, wrong sequence, wrong temporal relationship, or misidentifies which event came first/last. - Off-by-one tolerance: if the question asks for a count of days, weeks, or months, accept answers that differ by +/-1. This tolerance does NOT apply to years or to identifying specific events/artifacts. - Responses that correctly identify the fact but with verbose context are acceptable. - If the response hedges but still states the correct answer, it is correct. Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no""" elif task in ["Knowledge Update", "knowledge-update"]: template = """You are evaluating whether a model correctly answered a question about the most recent or current state of something that changed over time in a user's personal chat history. Question: {question} Reference Answer: {answer} Model Response: {response} Evaluation criteria: - CORRECT if the response correctly identifies the most recent/current state as described in the reference answer. - The response may include earlier historical states as context -- this is acceptable as long as the current/final state is correctly identified and clearly stated. - INCORRECT if the response: - States an outdated/superseded state as the current one - Omits the current state entirely - Correctly describes history but draws the wrong conclusion about what the current state is - Semantic equivalence counts (e.g., "flexitarian" and "semi-vegetarian diet with occasional meat" are equivalent if contextually clear). Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no""" elif task == "single-session-preference": template = """You are evaluating whether a model correctly answered a personalized question based on a user's stated preferences from their chat history. Question: {question} Reference Rubric: {answer} Model Response: {response} Evaluation criteria: - CORRECT if the response recalls and applies the user's personal preferences correctly, even if not covering every point in the rubric. - INCORRECT if the response ignores, contradicts, or misremembers the user's preferences. Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no""" else: template = """You are evaluating whether a model's response correctly answers a question based on a user's personal chat history. Question: {question} Reference Answer: {answer} Model Response: {response} Is the response correct? It is correct if it contains all key information from the reference answer, even if phrased differently. Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no""" return template.format(question=question, answer=answer, response=response) def build_rubric_prompt(task, question, answer): return f"""You are creating an atomic grading rubric for an open-ended QA benchmark. Question type: {task} Question: {question} Reference answer: {answer} Question-type guidance: {question_type_guidance(task)} Decompose the reference answer into the smallest independently checkable requirements needed to answer the question. Rules: - Each atom should be a single required answer unit: an entity, count, date/year, order relation, current-state conclusion, or abstention requirement. - If an entity and its temporal relation are inseparable for correctness, keep them in the same atom. - For list/count questions, include an atom for the exact count when the question asks "how many", and atoms for each required listed item when the item identities matter. - For Information Absence, usually use one atom requiring the response to clearly state that the information is unavailable/insufficient/not discussed, and add a strict note that concrete fabricated answers are wrong. - Do not include supporting evidence requirements or session IDs unless the question explicitly asks for them. - Weights should normally be 1.0. Use a higher weight only when one atom is clearly the main answer and other atoms are minor. Return JSON only with this schema: {{ "required_atoms": [ {{ "id": "a1", "requirement": "short, specific grading requirement", "weight": 1.0 }} ], "strict_notes": ["short note about exactness, ordering, abstention, or hallucination handling"] }}""" def build_atomic_eval_prompt(task, question, answer, response, rubric): rubric_str = json.dumps( { "required_atoms": rubric["required_atoms"], "strict_notes": rubric.get("strict_notes", []), }, ensure_ascii=False, indent=2, ) return f"""You are an LLM-as-a-judge evaluating one model response against a reference answer. Question type: {task} Question: {question} Reference answer: {answer} Model response: {response} Atomic grading rubric: {rubric_str} Question-type guidance: {question_type_guidance(task)} Judge each atom independently. Atom labels: - correct: the response fully satisfies this atom, allowing semantic paraphrase. - partially_correct: the response gets the main idea but is incomplete or slightly underspecified. Use this sparingly. Do not use it for wrong counts, wrong years, wrong named entities, wrong ordering, or a concrete answer to an Information Absence question. - missing: the response does not address this atom. - incorrect: the response contradicts this atom or gives the wrong count/entity/date/order/current state. Also identify unsupported_or_contradictory material: - severity "material": extra answer content that changes the final answer, adds extra items to an exact list/count, gives an outdated current state, fabricates a concrete answer for Information Absence, or contradicts any atom. - severity "minor": harmless context or extra explanation that does not change the answer. Return JSON only with this schema: {{ "atom_judgments": [ {{ "id": "a1", "label": "correct|partially_correct|missing|incorrect", "rationale": "brief reason" }} ], "unsupported_or_contradictory": [ {{ "text": "extra or contradictory claim", "severity": "minor|material", "rationale": "brief reason" }} ], "absence_mismatch": false, "overall_rationale": "one or two sentence summary" }}""" def llm_call( deployment_name: str, api_version: str, _prompt: str, debug: bool = False, vllm: bool = False, tritonai: bool = False, nvidia: bool = False, ): if nvidia: client = OpenAI( api_key=os.getenv("NV_API_KEY"), base_url="https://inference-api.nvidia.com", ) while True: try: return client.chat.completions.create( model=deployment_name, messages=[{"role": "user", "content": _prompt}], ) except Exception as e: st = _retryable_status(e) if st in (429, 500, 503, 403): print(f"[WARN] HTTP {st} from NVIDIA API; sleeping 60s then retrying...", flush=True) time.sleep(60) continue print("One exception captured", repr(e), flush=True) raise if tritonai: client = OpenAI( api_key=os.getenv("TRITONAI_API_KEY"), base_url=TRITONAI_BASE_URL, ) while True: try: return client.chat.completions.create( model=deployment_name, messages=[{"role": "user", "content": _prompt}], ) except Exception as e: st = _retryable_status(e) if st in (429, 500, 503, 403): print(f"[WARN] HTTP {st} from LiteLLM proxy; sleeping 60s then retrying...", flush=True) time.sleep(60) continue print("One exception captured", repr(e), flush=True) raise if deployment_name.startswith("claude-"): import anthropic client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) while True: try: msg = client.messages.create( model=deployment_name, max_tokens=1024, messages=[{"role": "user", "content": _prompt}], ) class _Choice: class _Msg: def __init__(self, text): self.content = text def __init__(self, text): self.message = self._Msg(text) class _Completion: def __init__(self, text): self.choices = [_Choice(text)] return _Completion(msg.content[0].text) except Exception as e: st = _retryable_status(e) if st in (429, 500, 503, 403): print(f"[WARN] HTTP {st} from Anthropic; sleeping 60s then retrying...", flush=True) time.sleep(60) continue print("One exception captured", repr(e), flush=True) raise if vllm: client = OpenAI( base_url=os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"), api_key=os.getenv("VLLM_API_KEY", "EMPTY"), ) elif debug: client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) else: client = AzureOpenAI( azure_endpoint=endpoint, azure_ad_token_provider=credential, api_version=api_version, ) kwargs = { "model": deployment_name, "messages": [{"role": "system", "content": _prompt}], } while True: try: return client.chat.completions.create(**kwargs) except Exception as e: st = _retryable_status(e) if st in (429, 500, 503, 403): print(f"[WARN] HTTP {st} from LLM; sleeping 120s then retrying...", flush=True) time.sleep(120) continue print("One exception captured", repr(e), flush=True) raise def call_json_llm(prompt, deployment_name, api_version, args, max_retries=3): last_error = None for attempt in range(max_retries): completion = llm_call( deployment_name, api_version, prompt, debug=args.debug, vllm=args.vllm, tritonai=args.tritonai, nvidia=args.nvidia, ) content = completion.choices[0].message.content.strip() try: return parse_json_object(content), content except Exception as e: last_error = e if attempt < max_retries - 1: print(f"[WARN] Failed to parse judge JSON; retrying ({attempt + 1}/{max_retries})", flush=True) time.sleep(2) raise ValueError(f"Failed to parse JSON response from judge: {last_error}") def fallback_rubric(qid, task, question, answer): return { "question_id": qid, "question_type": task, "question": question, "reference_answer": answer, "required_atoms": [ { "id": "a1", "requirement": f"Response must correctly answer the question according to the reference answer: {answer}", "weight": 1.0, } ], "strict_notes": ["Fallback single-atom rubric produced because the generated rubric was invalid."], "prompt_version": ATOMIC_PROMPT_VERSION, } def normalize_rubric(qid, task, question, answer, parsed): atoms = parsed.get("required_atoms", []) if isinstance(parsed, dict) else [] norm_atoms = [] for idx, atom in enumerate(atoms, start=1): if not isinstance(atom, dict): continue requirement = str(atom.get("requirement", "")).strip() if not requirement: continue atom_id = str(atom.get("id", f"a{idx}")).strip() or f"a{idx}" try: weight = float(atom.get("weight", 1.0)) except (TypeError, ValueError): weight = 1.0 if weight <= 0: weight = 1.0 norm_atoms.append({"id": atom_id, "requirement": requirement, "weight": weight}) if not norm_atoms: return fallback_rubric(qid, task, question, answer) strict_notes = parsed.get("strict_notes", []) if not isinstance(strict_notes, list): strict_notes = [str(strict_notes)] return { "question_id": qid, "question_type": task, "question": question, "reference_answer": answer, "required_atoms": norm_atoms, "strict_notes": [str(x) for x in strict_notes], "prompt_version": ATOMIC_PROMPT_VERSION, } def get_or_build_rubric(qdata, rubric_data, rubric_file, deployment_name, api_version, args): qid = qdata["question_id"] existing = rubric_data["rubrics"].get(qid) if ( existing and existing.get("prompt_version") == ATOMIC_PROMPT_VERSION and existing.get("required_atoms") and not args.force_rebuild_rubric ): return existing task = qdata["question_type"] prompt = build_rubric_prompt(task, qdata["question"], qdata["answer"]) try: parsed, raw = call_json_llm(prompt, deployment_name, api_version, args) rubric = normalize_rubric(qid, task, qdata["question"], qdata["answer"], parsed) rubric["rubric_raw_response"] = raw except Exception as e: print(f"[WARN] Falling back to single-atom rubric for {qid}: {e}", flush=True) rubric = fallback_rubric(qid, task, qdata["question"], qdata["answer"]) rubric_data["rubrics"][qid] = rubric write_json(rubric_file, rubric_data) return rubric def compute_atomic_scores(rubric, parsed): atoms = rubric.get("required_atoms", []) judgments_by_id = {} raw_judgments = parsed.get("atom_judgments", []) if isinstance(parsed, dict) else [] if isinstance(raw_judgments, list): for judgment in raw_judgments: if not isinstance(judgment, dict): continue atom_id = str(judgment.get("id", "")).strip() label = str(judgment.get("label", "")).strip() if label not in ATOM_SCORES: label = "incorrect" judgments_by_id[atom_id] = { "id": atom_id, "label": label, "score": ATOM_SCORES[label], "rationale": str(judgment.get("rationale", "")), } norm_judgments = [] weighted_score = 0.0 total_weight = 0.0 for atom in atoms: atom_id = atom["id"] weight = float(atom.get("weight", 1.0)) judgment = judgments_by_id.get( atom_id, {"id": atom_id, "label": "missing", "score": 0.0, "rationale": "No judgment returned."}, ) judgment["requirement"] = atom["requirement"] judgment["weight"] = weight norm_judgments.append(judgment) weighted_score += judgment["score"] * weight total_weight += weight extras = parsed.get("unsupported_or_contradictory", []) if isinstance(parsed, dict) else [] if not isinstance(extras, list): extras = [] material_extras = [ x for x in extras if isinstance(x, dict) and str(x.get("severity", "")).strip() == "material" ] absence_mismatch = bool(parsed.get("absence_mismatch", False)) if isinstance(parsed, dict) else False strict_label = ( bool(norm_judgments) and all(j["label"] == "correct" for j in norm_judgments) and not material_extras and not absence_mismatch ) partial_score = weighted_score / total_weight if total_weight > 0 else 0.0 if absence_mismatch: partial_score = 0.0 elif material_extras and partial_score > 0.8: partial_score = 0.8 return { "strict_label": strict_label, "partial_score": round(partial_score, 4), "atom_judgments": norm_judgments, "unsupported_or_contradictory": extras, "absence_mismatch": absence_mismatch, "overall_rationale": str(parsed.get("overall_rationale", "")) if isinstance(parsed, dict) else "", } def judge_atomic(qdata, hypothesis, rubric, deployment_name, api_version, args): prompt = build_atomic_eval_prompt( qdata["question_type"], qdata["question"], qdata["answer"], hypothesis, rubric, ) parsed, raw = call_json_llm(prompt, deployment_name, api_version, args) scores = compute_atomic_scores(rubric, parsed) return { "model": args.eval_model_name, "prompt_version": ATOMIC_PROMPT_VERSION, "eval_mode": args.eval_mode, "strict_label": scores["strict_label"], "partial_score": scores["partial_score"], "required_atoms": rubric["required_atoms"], "atom_judgments": scores["atom_judgments"], "unsupported_or_contradictory": scores["unsupported_or_contradictory"], "absence_mismatch": scores["absence_mismatch"], "overall_rationale": scores["overall_rationale"], "raw_response": raw, } def should_skip_existing(existing_row, eval_mode): if eval_mode == "legacy": return "autoeval_label" in existing_row atomic = existing_row.get("autoeval_atomic") return bool(atomic and atomic.get("prompt_version") == ATOMIC_PROMPT_VERSION) def safe_mean(values): if not values: return float("nan") return float(np.mean(values)) def print_legacy_summary(logs, qtype2acc): labels = [1 if x["autoeval_label"]["label"] else 0 for x in logs if "autoeval_label" in x] print("Accuracy:", round(safe_mean(labels), 4)) for k, v in sorted(qtype2acc.items()): print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v))) def print_atomic_summary(logs, qtype2strict, qtype2partial, eval_mode): strict_values = [ 1 if x["autoeval_atomic"]["strict_label"] else 0 for x in logs if "autoeval_atomic" in x ] partial_values = [ float(x["autoeval_atomic"]["partial_score"]) for x in logs if "autoeval_atomic" in x ] if eval_mode in ("strict", "both"): print("Strict Accuracy:", round(safe_mean(strict_values), 4)) for k, v in sorted(qtype2strict.items()): print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v))) if eval_mode in ("partial", "both"): print("Partial Score:", round(safe_mean(partial_values), 4)) for k, v in sorted(qtype2partial.items()): print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v))) def main(): parser = argparse.ArgumentParser() parser.add_argument("--hyp_file", type=str, default=None) parser.add_argument("--ref_file", type=str, required=True) parser.add_argument("--eval_model_name", type=str, required=True) parser.add_argument( "--eval_mode", type=str, default="both", choices=["legacy", "strict", "partial", "both"], help="legacy uses the old yes/no judge; strict/partial/both use atomic JSON judging.", ) parser.add_argument("--rubric_file", type=str, default=None) parser.add_argument("--build_rubric_only", action="store_true", default=False) parser.add_argument("--force_rebuild_rubric", action="store_true", default=False) parser.add_argument("--result_file", type=str, default=None) parser.add_argument("--debug", action="store_true", default=False) parser.add_argument("--vllm", action="store_true", default=False) parser.add_argument("--tritonai", action="store_true", default=False, help="Use OpenAI-compatible LiteLLM proxy (set TRITONAI_API_KEY env var)") parser.add_argument("--nvidia", action="store_true", default=False, help="Use NVIDIA inference API (set NV_API_KEY env var)") parser.add_argument("--verbose", action=argparse.BooleanOptionalAction, default=True) args = parser.parse_args() if not args.build_rubric_only and not args.hyp_file: parser.error("--hyp_file is required unless --build_rubric_only is set") metric_model = args.eval_model_name deployment_name, api_version = model_zoo[metric_model] references = read_json_or_jsonl(args.ref_file) qid2qdata = {entry["question_id"]: entry for entry in references} qid2qtype = {entry["question_id"]: entry["question_type"] for entry in references} qtypes = set(qid2qtype.values()) rubric_data = None rubric_file = args.rubric_file or default_rubric_file(args.ref_file) if args.eval_mode != "legacy" or args.build_rubric_only: rubric_data = load_rubric_file(rubric_file, args.ref_file) if args.build_rubric_only: for entry in tqdm(references, desc="building rubrics"): get_or_build_rubric(entry, rubric_data, rubric_file, deployment_name, api_version, args) print(f"Saved rubric file to {rubric_file}") return result_file = args.result_file or default_result_file(args.hyp_file, metric_model, args.eval_mode) existing = read_existing_jsonl(result_file) hypotheses = read_json_or_jsonl(args.hyp_file) qtype2acc = {t: [] for t in qtypes} qtype2strict = {t: [] for t in qtypes} qtype2partial = {t: [] for t in qtypes} logs = [] for entry in tqdm(hypotheses): qid = entry.get("question_id") if qid not in qid2qtype: if qid is not None: print(f"Warning: skipping {qid} as it is not in reference data.") continue if qid in existing and should_skip_existing(existing[qid], args.eval_mode): existing_row = existing[qid] logs.append(existing_row) qtype = qid2qtype[qid] if args.eval_mode == "legacy": label = existing_row["autoeval_label"]["label"] qtype2acc[qtype].append(1 if label else 0) else: atomic = existing_row["autoeval_atomic"] qtype2strict[qtype].append(1 if atomic["strict_label"] else 0) qtype2partial[qtype].append(float(atomic["partial_score"])) continue qdata = qid2qdata[qid] qtype = qdata["question_type"] hyp = entry["hypothesis"] if args.eval_mode == "legacy": prompt = get_anscheck_prompt(qtype, qdata["question"], qdata["answer"], hyp) completion = llm_call( deployment_name, api_version, prompt, debug=args.debug, vllm=args.vllm, tritonai=args.tritonai, nvidia=args.nvidia, ) eval_response = completion.choices[0].message.content.strip() last_line = next((l.strip().lower() for l in reversed(eval_response.splitlines()) if l.strip()), "") label = last_line == "yes" or last_line.startswith("yes") row = dict(entry) row["autoeval_label"] = { "model": metric_model, "prompt_version": LEGACY_PROMPT_VERSION, "label": label, "raw_response": eval_response, } logs.append(row) qtype2acc[qtype].append(1 if label else 0) if args.verbose: print(json.dumps({ "question": qdata["question"], "answer": qdata["answer"], "hypothesis": hyp, "autoeval_label": label, }, indent=4), flush=True) append_jsonl(result_file, row) continue rubric = get_or_build_rubric(qdata, rubric_data, rubric_file, deployment_name, api_version, args) try: atomic_eval = judge_atomic(qdata, hyp, rubric, deployment_name, api_version, args) except ValueError as _judge_err: print(f"[WARN] judge_atomic failed for {qdata['question_id']}, writing zero score: {_judge_err}", flush=True) atoms = rubric.get("required_atoms", []) atomic_eval = { "model": deployment_name, "prompt_version": ATOMIC_PROMPT_VERSION, "eval_mode": args.eval_mode, "strict_label": False, "partial_score": 0.0, "required_atoms": atoms, "atom_judgments": [{"id": a["id"], "label": "error", "score": 0.0, "rationale": "judge parse error", "requirement": a.get("requirement", ""), "weight": a.get("weight", 1.0)} for a in atoms], "unsupported_or_contradictory": [], "absence_mismatch": False, "overall_rationale": f"Skipped: judge JSON parse error ({_judge_err})", } row = dict(entry) row["autoeval_atomic"] = atomic_eval logs.append(row) qtype2strict[qtype].append(1 if atomic_eval["strict_label"] else 0) qtype2partial[qtype].append(float(atomic_eval["partial_score"])) if args.verbose: print(json.dumps({ "question": qdata["question"], "answer": qdata["answer"], "hypothesis": hyp, "strict_label": atomic_eval["strict_label"], "partial_score": atomic_eval["partial_score"], "atom_judgments": atomic_eval["atom_judgments"], }, indent=4), flush=True) append_jsonl(result_file, row) if args.eval_mode == "legacy": print_legacy_summary(logs, qtype2acc) else: print_atomic_summary(logs, qtype2strict, qtype2partial, args.eval_mode) print(f"Rubric file: {rubric_file}") print("Saved to", result_file) if __name__ == "__main__": main()