import argparse
import json
import os
import re
import time

import numpy as np
from openai import OpenAI
from tqdm import tqdm

try:
    from openai import AzureOpenAI
    from azure.identity import (
        AzureCliCredential,
        ChainedTokenCredential,
        ManagedIdentityCredential,
        get_bearer_token_provider,
    )

    AZURE_OAUTH_SCOPE = os.environ.get("AZURE_OAUTH_SCOPE", "")
    if AZURE_OAUTH_SCOPE:
        credential = get_bearer_token_provider(
            ChainedTokenCredential(
                AzureCliCredential(),
                ManagedIdentityCredential(),
            ),
            AZURE_OAUTH_SCOPE,
        )
    else:
        credential = None
except ImportError:
    AzureOpenAI = None
    credential = None

from model_zoo import model_zoo

endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "")

TRITONAI_BASE_URL = os.environ.get("LITELLM_BASE_URL", "")

ATOMIC_PROMPT_VERSION = "atomic-v1"
LEGACY_PROMPT_VERSION = "binary-v0"
ATOM_SCORES = {
    "correct": 1.0,
    "partially_correct": 0.5,
    "missing": 0.0,
    "incorrect": 0.0,
}


def _retryable_status(e) -> "int | None":
    status = getattr(e, "status_code", None) or getattr(e, "http_status", None)
    if status in (429, 500, 503, 403):
        return status
    resp = getattr(e, "response", None)
    if resp is not None and getattr(resp, "status_code", None) in (429, 500, 503):
        return resp.status_code
    msg = str(e).lower()
    if "429" in msg or "rate limit" in msg:
        return 429
    if "500" in msg or "internal server error" in msg:
        return 500
    if "503" in msg or "api configuration unavailable" in msg:
        return 503
    return None
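
# Illustrative check (a sketch, not executed anywhere in this script):
# _retryable_status looks at common status attributes first, then falls back
# to the message text. _FakeRateLimit is a hypothetical exception class.
#
#     class _FakeRateLimit(Exception):
#         status_code = 429
#
#     assert _retryable_status(_FakeRateLimit()) == 429
#     assert _retryable_status(Exception("Rate limit exceeded")) == 429
#     assert _retryable_status(Exception("bad request")) is None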


def parse_json_object(text):
    text = (text or "").strip()
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?", "", text).strip()
        text = re.sub(r"```$", "", text).strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start = text.find("{")
        end = text.rfind("}") + 1
        if start >= 0 and end > start:
            return json.loads(text[start:end])
        raise
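
# Example (illustrative): parse_json_object tolerates fenced and wrapped JSON,
# so both of these calls recover {"a": 1}:
#
#     parse_json_object('```json\n{"a": 1}\n```')
#     parse_json_object('Here is the result: {"a": 1} -- end of answer')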


def sanitize_model_name(name):
    return re.sub(r"[^A-Za-z0-9_.-]+", "_", name)
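
# e.g. sanitize_model_name("meta/llama-3 70B") -> "meta_llama-3_70B" (an
# illustrative name), so judge model names are safe to embed in file names.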


def read_json_or_jsonl(path):
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError:
        with open(path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]


def read_existing_jsonl(path):
    if not path or not os.path.exists(path):
        return {}
    rows = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            if "question_id" in obj:
                rows[obj["question_id"]] = obj
    return rows
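
# Example (illustrative): a results file containing the row
# {"question_id": "q1", "autoeval_atomic": {...}} loads as {"q1": {...}},
# which is what lets reruns below skip already-judged questions.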


def write_json(path, obj):
    # Write to a temp file, then atomically replace, so a crash mid-write
    # cannot corrupt the rubric file.
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
        f.write("\n")
    os.replace(tmp_path, path)


def append_jsonl(path, row):
    with open(path, "a", encoding="utf-8") as f:
        print(json.dumps(row, ensure_ascii=False), file=f, flush=True)


def default_result_file(hyp_file, metric_model, eval_mode):
    if eval_mode == "legacy":
        return f"{hyp_file}.eval-results-{metric_model}"
    model_tag = sanitize_model_name(metric_model)
    return f"{hyp_file}.eval-results-{model_tag}-{ATOMIC_PROMPT_VERSION}-{eval_mode}"


def default_rubric_file(ref_file):
    return f"{ref_file}.{ATOMIC_PROMPT_VERSION}.rubric.json"


def load_rubric_file(path, ref_file):
    if not path or not os.path.exists(path):
        return {
            "prompt_version": ATOMIC_PROMPT_VERSION,
            "source_ref_file": ref_file,
            "rubrics": {},
        }
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if "rubrics" in data:
        data.setdefault("prompt_version", ATOMIC_PROMPT_VERSION)
        data.setdefault("source_ref_file", ref_file)
        return data
    # Bare {qid: rubric} mappings from older runs get wrapped in the
    # current envelope.
    return {
        "prompt_version": ATOMIC_PROMPT_VERSION,
        "source_ref_file": ref_file,
        "rubrics": data,
    }
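
# Rubric file shape (sketch; file name is a placeholder):
#
#     {
#       "prompt_version": "atomic-v1",
#       "source_ref_file": "refs.json",
#       "rubrics": {
#         "<question_id>": {
#           "required_atoms": [{"id": "a1", "requirement": "...", "weight": 1.0}],
#           "strict_notes": ["..."]
#         }
#       }
#     }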


def question_type_guidance(task):
    if task in ["Information Absence"]:
        return (
            "The correct answer is that the information is unavailable, absent, "
            "not yet known, or not supported. A response that gives a concrete "
            "answer instead of abstaining is incorrect."
        )
    if task in ["Aggregation", "single-session-user", "single-session-assistant", "multi-session"]:
        return (
            "Check every requested item, count, list member, and named fact. "
            "Exact counts and required list coverage matter. Extra material "
            "facts that change the answer should be flagged."
        )
    if task in ["Aggregation + Temporal"]:
        return (
            "Check both the aggregated facts and their time/order associations. "
            "An answer can name the right item but still be wrong if the timing, "
            "ordering, before/after relation, or year is wrong."
        )
    if task in ["Temporal Reasoning", "temporal-reasoning"]:
        return (
            "Check the specific time, date, year, sequence, duration, or temporal "
            "relationship asked for. Accept +/-1 only for day/week/month duration "
            "counts, not for years, event identity, or ordering."
        )
    if task in ["Knowledge Update", "knowledge-update"]:
        return (
            "Check the current or most recent state. Historical context is fine "
            "only if the final/current state is clearly correct. Outdated states "
            "presented as current are incorrect."
        )
    if task == "single-session-preference":
        return (
            "Check whether the response recalls and applies the stated preference. "
            "Do not require unnecessary verbosity, but contradictions are incorrect."
        )
    return "Check whether the model response satisfies all required facts in the reference answer."


def get_anscheck_prompt(task, question, answer, response):
    """Legacy binary yes/no prompt kept for backward compatibility."""
    if task in ["Information Absence"]:
        template = """You are evaluating whether a model correctly identified that a question is unanswerable given the available personal chat history.

Question: {question}

Explanation of why it is unanswerable: {answer}

Model Response: {response}

Evaluation criteria:
- CORRECT if the model explicitly states that the information is not available, insufficient, or that the question cannot be answered from the provided context.
- INCORRECT if the model fabricates an answer or fails to acknowledge the unanswerable nature of the question.
- The model does not need to use the exact word "unanswerable" -- expressing uncertainty or lack of information is sufficient.

Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
    elif task in ["Aggregation", "single-session-user", "single-session-assistant", "multi-session"]:
        template = """You are evaluating whether a model correctly answered a question that requires aggregating specific facts from a user's personal chat history.

Question: {question}

Reference Answer: {answer}

Model Response: {response}

Evaluation criteria:
- CORRECT if the response identifies all key items, facts, or counts present in the reference answer, even if phrased differently or with added context.
- INCORRECT if the response:
  - States a wrong count (e.g., says "5" when the answer is "3")
  - Omits one or more key items/facts listed in the reference
  - Lists mostly wrong items even if the count is right
- Partial answers that cover only a subset of required items are INCORRECT.
- Verbose responses are acceptable as long as all reference items are present within them.
- If the response contains correct items but also lists additional plausible-sounding but unverified items beyond the reference, this does NOT make it incorrect -- evaluate only whether the reference items are covered.
- Semantic equivalence counts as correct (e.g., "RSI" = "Repetitive Strain Injury").

Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
    elif task in ["Aggregation + Temporal"]:
        template = """You are evaluating whether a model correctly answered a question that requires both aggregating facts AND reasoning about their temporal order or time associations from a user's personal chat history.

Question: {question}

Reference Answer: {answer}

Model Response: {response}

Evaluation criteria:
- CORRECT if the response captures both:
  (a) all key events/facts listed in the reference, and
  (b) their correct temporal associations (ordering, time periods, or "before/after" relationships).
- INCORRECT if the response:
  - Omits one or more key events or facts from the reference
  - Gets the temporal ordering or time associations wrong
  - Captures only the content without the temporal aspects, or vice versa
- Responses that describe the correct progression/sequence in different words are acceptable.
- Partial answers covering only some events or ignoring time aspects are INCORRECT.
- Minor wording differences or additional explanatory context are acceptable.

Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
    elif task in ["Temporal Reasoning", "temporal-reasoning"]:
        template = """You are evaluating whether a model correctly answered a question about temporal relationships in a user's personal chat history.

Question: {question}

Reference Answer: {answer}

Model Response: {response}

Evaluation criteria:
- CORRECT if the response correctly identifies the specific time, date, year, sequence, or temporal relationship asked about.
- INCORRECT if the response states a wrong year, wrong sequence, wrong temporal relationship, or misidentifies which event came first/last.
- Off-by-one tolerance: if the question asks for a count of days, weeks, or months, accept answers that differ by +/-1. This tolerance does NOT apply to years or to identifying specific events/artifacts.
- Responses that correctly identify the fact but with verbose context are acceptable.
- If the response hedges but still states the correct answer, it is correct.

Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
    elif task in ["Knowledge Update", "knowledge-update"]:
        template = """You are evaluating whether a model correctly answered a question about the most recent or current state of something that changed over time in a user's personal chat history.

Question: {question}

Reference Answer: {answer}

Model Response: {response}

Evaluation criteria:
- CORRECT if the response correctly identifies the most recent/current state as described in the reference answer.
- The response may include earlier historical states as context -- this is acceptable as long as the current/final state is correctly identified and clearly stated.
- INCORRECT if the response:
  - States an outdated/superseded state as the current one
  - Omits the current state entirely
  - Correctly describes history but draws the wrong conclusion about what the current state is
- Semantic equivalence counts (e.g., "flexitarian" and "semi-vegetarian diet with occasional meat" are equivalent if contextually clear).

Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
    elif task == "single-session-preference":
        template = """You are evaluating whether a model correctly answered a personalized question based on a user's stated preferences from their chat history.

Question: {question}

Reference Rubric: {answer}

Model Response: {response}

Evaluation criteria:
- CORRECT if the response recalls and applies the user's personal preferences correctly, even if not covering every point in the rubric.
- INCORRECT if the response ignores, contradicts, or misremembers the user's preferences.

Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
    else:
        template = """You are evaluating whether a model's response correctly answers a question based on a user's personal chat history.

Question: {question}

Reference Answer: {answer}

Model Response: {response}

Is the response correct? It is correct if it contains all key information from the reference answer, even if phrased differently.

Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
    return template.format(question=question, answer=answer, response=response)


def build_rubric_prompt(task, question, answer):
    return f"""You are creating an atomic grading rubric for an open-ended QA benchmark.

Question type: {task}
Question:
{question}

Reference answer:
{answer}

Question-type guidance:
{question_type_guidance(task)}

Decompose the reference answer into the smallest independently checkable requirements needed to answer the question.

Rules:
- Each atom should be a single required answer unit: an entity, count, date/year, order relation, current-state conclusion, or abstention requirement.
- If an entity and its temporal relation are inseparable for correctness, keep them in the same atom.
- For list/count questions, include an atom for the exact count when the question asks "how many", and atoms for each required listed item when the item identities matter.
- For Information Absence, usually use one atom requiring the response to clearly state that the information is unavailable/insufficient/not discussed, and add a strict note that concrete fabricated answers are wrong.
- Do not include supporting evidence requirements or session IDs unless the question explicitly asks for them.
- Weights should normally be 1.0. Use a higher weight only when one atom is clearly the main answer and other atoms are minor.

Return JSON only with this schema:
{{
  "required_atoms": [
    {{
      "id": "a1",
      "requirement": "short, specific grading requirement",
      "weight": 1.0
    }}
  ],
  "strict_notes": ["short note about exactness, ordering, abstention, or hallucination handling"]
}}"""


def build_atomic_eval_prompt(task, question, answer, response, rubric):
    rubric_str = json.dumps(
        {
            "required_atoms": rubric["required_atoms"],
            "strict_notes": rubric.get("strict_notes", []),
        },
        ensure_ascii=False,
        indent=2,
    )
    return f"""You are an LLM-as-a-judge evaluating one model response against a reference answer.

Question type: {task}
Question:
{question}

Reference answer:
{answer}

Model response:
{response}

Atomic grading rubric:
{rubric_str}

Question-type guidance:
{question_type_guidance(task)}

Judge each atom independently.

Atom labels:
- correct: the response fully satisfies this atom, allowing semantic paraphrase.
- partially_correct: the response gets the main idea but is incomplete or slightly underspecified. Use this sparingly. Do not use it for wrong counts, wrong years, wrong named entities, wrong ordering, or a concrete answer to an Information Absence question.
- missing: the response does not address this atom.
- incorrect: the response contradicts this atom or gives the wrong count/entity/date/order/current state.

Also identify unsupported_or_contradictory material:
- severity "material": extra answer content that changes the final answer, adds extra items to an exact list/count, gives an outdated current state, fabricates a concrete answer for Information Absence, or contradicts any atom.
- severity "minor": harmless context or extra explanation that does not change the answer.

Return JSON only with this schema:
{{
  "atom_judgments": [
    {{
      "id": "a1",
      "label": "correct|partially_correct|missing|incorrect",
      "rationale": "brief reason"
    }}
  ],
  "unsupported_or_contradictory": [
    {{
      "text": "extra or contradictory claim",
      "severity": "minor|material",
      "rationale": "brief reason"
    }}
  ],
  "absence_mismatch": false,
  "overall_rationale": "one or two sentence summary"
}}"""


def llm_call(
    deployment_name: str,
    api_version: str,
    _prompt: str,
    debug: bool = False,
    vllm: bool = False,
    tritonai: bool = False,
    nvidia: bool = False,
):
    if nvidia:
        client = OpenAI(
            api_key=os.getenv("NV_API_KEY"),
            base_url="https://inference-api.nvidia.com",
        )
        while True:
            try:
                return client.chat.completions.create(
                    model=deployment_name,
                    messages=[{"role": "user", "content": _prompt}],
                )
            except Exception as e:
                st = _retryable_status(e)
                if st in (429, 500, 503, 403):
                    print(f"[WARN] HTTP {st} from NVIDIA API; sleeping 60s then retrying...", flush=True)
                    time.sleep(60)
                    continue
                print("One exception captured", repr(e), flush=True)
                raise

    if tritonai:
        client = OpenAI(
            api_key=os.getenv("TRITONAI_API_KEY"),
            base_url=TRITONAI_BASE_URL,
        )
        while True:
            try:
                return client.chat.completions.create(
                    model=deployment_name,
                    messages=[{"role": "user", "content": _prompt}],
                )
            except Exception as e:
                st = _retryable_status(e)
                if st in (429, 500, 503, 403):
                    print(f"[WARN] HTTP {st} from LiteLLM proxy; sleeping 60s then retrying...", flush=True)
                    time.sleep(60)
                    continue
                print("One exception captured", repr(e), flush=True)
                raise

    if deployment_name.startswith("claude-"):
        import anthropic

        client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
        while True:
            try:
                msg = client.messages.create(
                    model=deployment_name,
                    max_tokens=1024,
                    messages=[{"role": "user", "content": _prompt}],
                )

                # Minimal shims so the Anthropic response exposes the same
                # completion.choices[0].message.content shape as OpenAI ones.
                class _Choice:
                    class _Msg:
                        def __init__(self, text):
                            self.content = text

                    def __init__(self, text):
                        self.message = self._Msg(text)

                class _Completion:
                    def __init__(self, text):
                        self.choices = [_Choice(text)]

                return _Completion(msg.content[0].text)
            except Exception as e:
                st = _retryable_status(e)
                if st in (429, 500, 503, 403):
                    print(f"[WARN] HTTP {st} from Anthropic; sleeping 60s then retrying...", flush=True)
                    time.sleep(60)
                    continue
                print("One exception captured", repr(e), flush=True)
                raise

    if vllm:
        client = OpenAI(
            base_url=os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
            api_key=os.getenv("VLLM_API_KEY", "EMPTY"),
        )
    elif debug:
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    else:
        client = AzureOpenAI(
            azure_endpoint=endpoint,
            azure_ad_token_provider=credential,
            api_version=api_version,
        )

    # Note: the OpenAI/Azure/vLLM paths send the judge prompt as a single
    # system message, matching the original behavior of this script.
    kwargs = {
        "model": deployment_name,
        "messages": [{"role": "system", "content": _prompt}],
    }
    while True:
        try:
            return client.chat.completions.create(**kwargs)
        except Exception as e:
            st = _retryable_status(e)
            if st in (429, 500, 503, 403):
                print(f"[WARN] HTTP {st} from LLM; sleeping 120s then retrying...", flush=True)
                time.sleep(120)
                continue
            print("One exception captured", repr(e), flush=True)
            raise


def call_json_llm(prompt, deployment_name, api_version, args, max_retries=3):
    last_error = None
    for attempt in range(max_retries):
        completion = llm_call(
            deployment_name,
            api_version,
            prompt,
            debug=args.debug,
            vllm=args.vllm,
            tritonai=args.tritonai,
            nvidia=args.nvidia,
        )
        content = completion.choices[0].message.content.strip()
        try:
            return parse_json_object(content), content
        except Exception as e:
            last_error = e
            if attempt < max_retries - 1:
                print(f"[WARN] Failed to parse judge JSON; retrying ({attempt + 1}/{max_retries})", flush=True)
                time.sleep(2)
    # Only raised once every parse attempt has failed.
    raise ValueError(f"Failed to parse JSON response from judge: {last_error}")


def fallback_rubric(qid, task, question, answer):
    return {
        "question_id": qid,
        "question_type": task,
        "question": question,
        "reference_answer": answer,
        "required_atoms": [
            {
                "id": "a1",
                "requirement": f"Response must correctly answer the question according to the reference answer: {answer}",
                "weight": 1.0,
            }
        ],
        "strict_notes": ["Fallback single-atom rubric produced because the generated rubric was invalid."],
        "prompt_version": ATOMIC_PROMPT_VERSION,
    }


def normalize_rubric(qid, task, question, answer, parsed):
    atoms = parsed.get("required_atoms", []) if isinstance(parsed, dict) else []
    norm_atoms = []
    for idx, atom in enumerate(atoms, start=1):
        if not isinstance(atom, dict):
            continue
        requirement = str(atom.get("requirement", "")).strip()
        if not requirement:
            continue
        atom_id = str(atom.get("id", f"a{idx}")).strip() or f"a{idx}"
        try:
            weight = float(atom.get("weight", 1.0))
        except (TypeError, ValueError):
            weight = 1.0
        if weight <= 0:
            weight = 1.0
        norm_atoms.append({"id": atom_id, "requirement": requirement, "weight": weight})
    if not norm_atoms:
        return fallback_rubric(qid, task, question, answer)
    strict_notes = parsed.get("strict_notes", [])
    if not isinstance(strict_notes, list):
        strict_notes = [str(strict_notes)]
    return {
        "question_id": qid,
        "question_type": task,
        "question": question,
        "reference_answer": answer,
        "required_atoms": norm_atoms,
        "strict_notes": [str(x) for x in strict_notes],
        "prompt_version": ATOMIC_PROMPT_VERSION,
    }
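
# Example (illustrative): a parsed judge reply of
#     {"required_atoms": [{"requirement": "Mentions Paris"}, {"requirement": ""}]}
# normalizes to one atom {"id": "a1", "requirement": "Mentions Paris",
# "weight": 1.0}; if no valid atoms survive, fallback_rubric() is used instead.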


def get_or_build_rubric(qdata, rubric_data, rubric_file, deployment_name, api_version, args):
    qid = qdata["question_id"]
    existing = rubric_data["rubrics"].get(qid)
    if (
        existing
        and existing.get("prompt_version") == ATOMIC_PROMPT_VERSION
        and existing.get("required_atoms")
        and not args.force_rebuild_rubric
    ):
        return existing

    task = qdata["question_type"]
    prompt = build_rubric_prompt(task, qdata["question"], qdata["answer"])
    try:
        parsed, raw = call_json_llm(prompt, deployment_name, api_version, args)
        rubric = normalize_rubric(qid, task, qdata["question"], qdata["answer"], parsed)
        rubric["rubric_raw_response"] = raw
    except Exception as e:
        print(f"[WARN] Falling back to single-atom rubric for {qid}: {e}", flush=True)
        rubric = fallback_rubric(qid, task, qdata["question"], qdata["answer"])
    rubric_data["rubrics"][qid] = rubric
    write_json(rubric_file, rubric_data)
    return rubric


def compute_atomic_scores(rubric, parsed):
    atoms = rubric.get("required_atoms", [])
    judgments_by_id = {}
    raw_judgments = parsed.get("atom_judgments", []) if isinstance(parsed, dict) else []
    if isinstance(raw_judgments, list):
        for judgment in raw_judgments:
            if not isinstance(judgment, dict):
                continue
            atom_id = str(judgment.get("id", "")).strip()
            label = str(judgment.get("label", "")).strip()
            if label not in ATOM_SCORES:
                label = "incorrect"
            judgments_by_id[atom_id] = {
                "id": atom_id,
                "label": label,
                "score": ATOM_SCORES[label],
                "rationale": str(judgment.get("rationale", "")),
            }

    norm_judgments = []
    weighted_score = 0.0
    total_weight = 0.0
    for atom in atoms:
        atom_id = atom["id"]
        weight = float(atom.get("weight", 1.0))
        judgment = judgments_by_id.get(
            atom_id,
            {"id": atom_id, "label": "missing", "score": 0.0, "rationale": "No judgment returned."},
        )
        judgment["requirement"] = atom["requirement"]
        judgment["weight"] = weight
        norm_judgments.append(judgment)
        weighted_score += judgment["score"] * weight
        total_weight += weight

    extras = parsed.get("unsupported_or_contradictory", []) if isinstance(parsed, dict) else []
    if not isinstance(extras, list):
        extras = []
    material_extras = [
        x for x in extras
        if isinstance(x, dict) and str(x.get("severity", "")).strip() == "material"
    ]
    absence_mismatch = bool(parsed.get("absence_mismatch", False)) if isinstance(parsed, dict) else False

    strict_label = (
        bool(norm_judgments)
        and all(j["label"] == "correct" for j in norm_judgments)
        and not material_extras
        and not absence_mismatch
    )
    partial_score = weighted_score / total_weight if total_weight > 0 else 0.0
    if absence_mismatch:
        partial_score = 0.0
    elif material_extras and partial_score > 0.8:
        partial_score = 0.8

    return {
        "strict_label": strict_label,
        "partial_score": round(partial_score, 4),
        "atom_judgments": norm_judgments,
        "unsupported_or_contradictory": extras,
        "absence_mismatch": absence_mismatch,
        "overall_rationale": str(parsed.get("overall_rationale", "")) if isinstance(parsed, dict) else "",
    }
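
# Worked example (illustrative): two atoms, each weight 1.0, judged "correct"
# and "partially_correct" give partial_score = (1.0 + 0.5) / 2 = 0.75 and
# strict_label = False. A "material" unsupported claim would cap partial_score
# at 0.8 (no effect at 0.75), and absence_mismatch forces it to 0.0.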


def judge_atomic(qdata, hypothesis, rubric, deployment_name, api_version, args):
    prompt = build_atomic_eval_prompt(
        qdata["question_type"],
        qdata["question"],
        qdata["answer"],
        hypothesis,
        rubric,
    )
    parsed, raw = call_json_llm(prompt, deployment_name, api_version, args)
    scores = compute_atomic_scores(rubric, parsed)
    return {
        "model": args.eval_model_name,
        "prompt_version": ATOMIC_PROMPT_VERSION,
        "eval_mode": args.eval_mode,
        "strict_label": scores["strict_label"],
        "partial_score": scores["partial_score"],
        "required_atoms": rubric["required_atoms"],
        "atom_judgments": scores["atom_judgments"],
        "unsupported_or_contradictory": scores["unsupported_or_contradictory"],
        "absence_mismatch": scores["absence_mismatch"],
        "overall_rationale": scores["overall_rationale"],
        "raw_response": raw,
    }


def should_skip_existing(existing_row, eval_mode):
    if eval_mode == "legacy":
        return "autoeval_label" in existing_row
    atomic = existing_row.get("autoeval_atomic")
    return bool(atomic and atomic.get("prompt_version") == ATOMIC_PROMPT_VERSION)


def safe_mean(values):
    if not values:
        return float("nan")
    return float(np.mean(values))


def print_legacy_summary(logs, qtype2acc):
    labels = [1 if x["autoeval_label"]["label"] else 0 for x in logs if "autoeval_label" in x]
    print("Accuracy:", round(safe_mean(labels), 4))
    for k, v in sorted(qtype2acc.items()):
        print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v)))


def print_atomic_summary(logs, qtype2strict, qtype2partial, eval_mode):
    strict_values = [
        1 if x["autoeval_atomic"]["strict_label"] else 0
        for x in logs
        if "autoeval_atomic" in x
    ]
    partial_values = [
        float(x["autoeval_atomic"]["partial_score"])
        for x in logs
        if "autoeval_atomic" in x
    ]
    if eval_mode in ("strict", "both"):
        print("Strict Accuracy:", round(safe_mean(strict_values), 4))
        for k, v in sorted(qtype2strict.items()):
            print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v)))
    if eval_mode in ("partial", "both"):
        print("Partial Score:", round(safe_mean(partial_values), 4))
        for k, v in sorted(qtype2partial.items()):
            print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v)))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--hyp_file", type=str, default=None)
    parser.add_argument("--ref_file", type=str, required=True)
    parser.add_argument("--eval_model_name", type=str, required=True)
    parser.add_argument(
        "--eval_mode",
        type=str,
        default="both",
        choices=["legacy", "strict", "partial", "both"],
        help="legacy uses the old yes/no judge; strict/partial/both use atomic JSON judging.",
    )
    parser.add_argument("--rubric_file", type=str, default=None)
    parser.add_argument("--build_rubric_only", action="store_true", default=False)
    parser.add_argument("--force_rebuild_rubric", action="store_true", default=False)
    parser.add_argument("--result_file", type=str, default=None)
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--vllm", action="store_true", default=False)
    parser.add_argument("--tritonai", action="store_true", default=False,
                        help="Use OpenAI-compatible LiteLLM proxy (set TRITONAI_API_KEY env var)")
    parser.add_argument("--nvidia", action="store_true", default=False,
                        help="Use NVIDIA inference API (set NV_API_KEY env var)")
    parser.add_argument("--verbose", action=argparse.BooleanOptionalAction, default=True)
    args = parser.parse_args()

    if not args.build_rubric_only and not args.hyp_file:
        parser.error("--hyp_file is required unless --build_rubric_only is set")

    metric_model = args.eval_model_name
    deployment_name, api_version = model_zoo[metric_model]
    references = read_json_or_jsonl(args.ref_file)
    qid2qdata = {entry["question_id"]: entry for entry in references}
    qid2qtype = {entry["question_id"]: entry["question_type"] for entry in references}
    qtypes = set(qid2qtype.values())

    rubric_data = None
    rubric_file = args.rubric_file or default_rubric_file(args.ref_file)
    if args.eval_mode != "legacy" or args.build_rubric_only:
        rubric_data = load_rubric_file(rubric_file, args.ref_file)

    if args.build_rubric_only:
        for entry in tqdm(references, desc="building rubrics"):
            get_or_build_rubric(entry, rubric_data, rubric_file, deployment_name, api_version, args)
        print(f"Saved rubric file to {rubric_file}")
        return

    result_file = args.result_file or default_result_file(args.hyp_file, metric_model, args.eval_mode)
    existing = read_existing_jsonl(result_file)
    hypotheses = read_json_or_jsonl(args.hyp_file)

    qtype2acc = {t: [] for t in qtypes}
    qtype2strict = {t: [] for t in qtypes}
    qtype2partial = {t: [] for t in qtypes}
    logs = []

    for entry in tqdm(hypotheses):
        qid = entry.get("question_id")
        if qid not in qid2qtype:
            if qid is not None:
                print(f"Warning: skipping {qid} as it is not in reference data.")
            continue

        if qid in existing and should_skip_existing(existing[qid], args.eval_mode):
            existing_row = existing[qid]
            logs.append(existing_row)
            qtype = qid2qtype[qid]
            if args.eval_mode == "legacy":
                label = existing_row["autoeval_label"]["label"]
                qtype2acc[qtype].append(1 if label else 0)
            else:
                atomic = existing_row["autoeval_atomic"]
                qtype2strict[qtype].append(1 if atomic["strict_label"] else 0)
                qtype2partial[qtype].append(float(atomic["partial_score"]))
            continue

        qdata = qid2qdata[qid]
        qtype = qdata["question_type"]
        hyp = entry["hypothesis"]

        if args.eval_mode == "legacy":
            prompt = get_anscheck_prompt(qtype, qdata["question"], qdata["answer"], hyp)
            completion = llm_call(
                deployment_name,
                api_version,
                prompt,
                debug=args.debug,
                vllm=args.vllm,
                tritonai=args.tritonai,
                nvidia=args.nvidia,
            )
            eval_response = completion.choices[0].message.content.strip()
            # The judge is asked to end with a bare "yes" or "no"; grade on the
            # last non-empty line (startswith also covers trailing punctuation).
            last_line = next((l.strip().lower() for l in reversed(eval_response.splitlines()) if l.strip()), "")
            label = last_line.startswith("yes")
            row = dict(entry)
            row["autoeval_label"] = {
                "model": metric_model,
                "prompt_version": LEGACY_PROMPT_VERSION,
                "label": label,
                "raw_response": eval_response,
            }
            logs.append(row)
            qtype2acc[qtype].append(1 if label else 0)
            if args.verbose:
                print(json.dumps({
                    "question": qdata["question"],
                    "answer": qdata["answer"],
                    "hypothesis": hyp,
                    "autoeval_label": label,
                }, indent=4), flush=True)
            append_jsonl(result_file, row)
            continue

        rubric = get_or_build_rubric(qdata, rubric_data, rubric_file, deployment_name, api_version, args)
        try:
            atomic_eval = judge_atomic(qdata, hyp, rubric, deployment_name, api_version, args)
        except ValueError as _judge_err:
            print(f"[WARN] judge_atomic failed for {qdata['question_id']}, writing zero score: {_judge_err}", flush=True)
            atoms = rubric.get("required_atoms", [])
            atomic_eval = {
                "model": deployment_name,
                "prompt_version": ATOMIC_PROMPT_VERSION,
                "eval_mode": args.eval_mode,
                "strict_label": False,
                "partial_score": 0.0,
                "required_atoms": atoms,
                "atom_judgments": [
                    {
                        "id": a["id"],
                        "label": "error",
                        "score": 0.0,
                        "rationale": "judge parse error",
                        "requirement": a.get("requirement", ""),
                        "weight": a.get("weight", 1.0),
                    }
                    for a in atoms
                ],
                "unsupported_or_contradictory": [],
                "absence_mismatch": False,
                "overall_rationale": f"Skipped: judge JSON parse error ({_judge_err})",
            }
        row = dict(entry)
        row["autoeval_atomic"] = atomic_eval
        logs.append(row)
        qtype2strict[qtype].append(1 if atomic_eval["strict_label"] else 0)
        qtype2partial[qtype].append(float(atomic_eval["partial_score"]))
        if args.verbose:
            print(json.dumps({
                "question": qdata["question"],
                "answer": qdata["answer"],
                "hypothesis": hyp,
                "strict_label": atomic_eval["strict_label"],
                "partial_score": atomic_eval["partial_score"],
                "atom_judgments": atomic_eval["atom_judgments"],
            }, indent=4), flush=True)
        append_jsonl(result_file, row)

    if args.eval_mode == "legacy":
        print_legacy_summary(logs, qtype2acc)
    else:
        print_atomic_summary(logs, qtype2strict, qtype2partial, args.eval_mode)
        print(f"Rubric file: {rubric_file}")
    print("Saved to", result_file)


if __name__ == "__main__":
    main()