File size: 36,176 Bytes
9c60174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
import argparse
import json
import os
import re
import time

import numpy as np
from openai import OpenAI
from tqdm import tqdm

# Azure support is optional: the Azure OpenAI client class and the
# azure-identity helpers are imported together, and a failure of either import
# disables the whole Azure path (see the ``except ImportError`` branch below).
try:
    from openai import AzureOpenAI
    from azure.identity import (
        AzureCliCredential,
        ChainedTokenCredential,
        ManagedIdentityCredential,
        get_bearer_token_provider,
    )

    # Scope passed to the bearer-token provider; read from the
    # AZURE_OAUTH_SCOPE environment variable (empty string when unset).
    AZURE_OAUTH_SCOPE = os.environ.get("AZURE_OAUTH_SCOPE", "")
    if AZURE_OAUTH_SCOPE:
        # Token provider built from a credential chain listing the Azure CLI
        # credential first and the managed-identity credential second.
        credential = get_bearer_token_provider(
            ChainedTokenCredential(
                AzureCliCredential(),
                ManagedIdentityCredential(),
            ),
            AZURE_OAUTH_SCOPE,
        )
    else:
        # No scope configured, so no token provider is created.
        credential = None
except ImportError:
    # Either import above failed: mark the Azure client as unavailable.
    AzureOpenAI = None
    credential = None

from model_zoo import model_zoo


# Azure OpenAI endpoint (set AZURE_OPENAI_ENDPOINT env var to your deployment URL).
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "")
# OpenAI-compatible LiteLLM proxy URL (set LITELLM_BASE_URL env var to your proxy).
TRITONAI_BASE_URL = os.environ.get("LITELLM_BASE_URL", "")

# Version tag for the atomic-rubric prompts. It is embedded in default result
# and rubric file names, stored on every rubric and atomic judgment, and
# compared against stored values to decide whether cached work can be reused.
ATOMIC_PROMPT_VERSION = "atomic-v1"
# Version tag for the legacy yes/no prompts.
# NOTE(review): not referenced in the part of the file visible here -- confirm usage.
LEGACY_PROMPT_VERSION = "binary-v0"
# Numeric credit awarded per atom label. Labels outside this mapping are
# treated as "incorrect" when judgments are scored.
ATOM_SCORES = {
    "correct": 1.0,
    "partially_correct": 0.5,
    "missing": 0.0,
    "incorrect": 0.0,
}


def _retryable_status(e) -> "int | None":
    """Return the retryable HTTP status implied by exception *e*, or ``None``.

    The status is looked for in three places, in order:

    1. ``e.status_code`` / ``e.http_status``;
    2. ``e.response.status_code``;
    3. the lower-cased exception message (substring heuristics for 429, 500
       and 503 only).

    429, 500, 503 and 403 are treated as retryable in both attribute-based
    checks. Previously the ``e.response`` check omitted 403 even though every
    caller in this file retries on 403, so the same error was retried or not
    depending on where the SDK happened to store the status code.

    Args:
        e: The exception raised by an LLM client call.

    Returns:
        The retryable status code, or ``None`` when the error should not be
        retried.
    """
    retryable = (429, 500, 503, 403)

    status = getattr(e, "status_code", None) or getattr(e, "http_status", None)
    if status in retryable:
        return status

    resp = getattr(e, "response", None)
    resp_status = getattr(resp, "status_code", None) if resp is not None else None
    if resp_status in retryable:
        return resp_status

    # Last resort: some client errors only expose the status in their text.
    msg = str(e).lower()
    if "429" in msg or "rate limit" in msg:
        return 429
    if "500" in msg or "internal server error" in msg:
        return 500
    if "503" in msg or "api configuration unavailable" in msg:
        return 503
    return None


def parse_json_object(text):
    """Decode JSON from judge output, tolerating code fences and surrounding text.

    ``None`` is treated as an empty string. A leading Markdown fence (with an
    optional ``json`` tag) and a trailing fence are removed before decoding.
    If the cleaned text is not valid JSON, the span from the first ``{`` to
    the last ``}`` is decoded instead.

    Raises:
        json.JSONDecodeError: If neither the cleaned text nor the brace-delimited
            span can be decoded.
    """
    cleaned = (text or "").strip()
    if cleaned.startswith("```"):
        cleaned = re.sub(r"^```(?:json)?", "", cleaned).strip()
        cleaned = re.sub(r"```$", "", cleaned).strip()

    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        first_brace = cleaned.find("{")
        past_last_brace = cleaned.rfind("}") + 1
        if first_brace < 0 or past_last_brace <= first_brace:
            # Nothing that looks like an object: surface the original error.
            raise
        return json.loads(cleaned[first_brace:past_last_brace])


def sanitize_model_name(name):
    """Collapse each run of characters outside ``[A-Za-z0-9_.-]`` into one ``_``."""
    unsafe_run = re.compile(r"[^A-Za-z0-9_.-]+")
    return unsafe_run.sub("_", name)


def read_json_or_jsonl(path):
    """Load *path* as one JSON document, falling back to JSON Lines.

    The file is first decoded as a single JSON value. If that fails with a
    ``json.JSONDecodeError`` it is read again line by line, decoding every
    non-blank line, and the decoded values are returned as a list.
    """
    with open(path, "r", encoding="utf-8") as handle:
        try:
            return json.load(handle)
        except json.JSONDecodeError:
            pass  # Not a single document; retry below as JSON Lines.

    records = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            if raw_line.strip():
                records.append(json.loads(raw_line))
    return records


def read_existing_jsonl(path):
    """Index the rows of a JSON Lines file by their ``question_id``.

    Returns an empty dict when *path* is falsy or does not exist. Blank lines
    and rows without a ``question_id`` key are ignored; when an id appears
    more than once, the last row wins.
    """
    by_qid = {}
    if path and os.path.exists(path):
        with open(path, "r", encoding="utf-8") as handle:
            decoded_rows = (json.loads(ln) for ln in handle if ln.strip())
            for record in decoded_rows:
                if "question_id" in record:
                    by_qid[record["question_id"]] = record
    return by_qid


def write_json(path, obj):
    """Write *obj* to *path* as indented UTF-8 JSON via a temporary sibling file.

    The document is written to ``path + ".tmp"`` with a trailing newline and
    then moved over *path* with ``os.replace``.
    """
    staging_path = path + ".tmp"
    with open(staging_path, "w", encoding="utf-8") as handle:
        json.dump(obj, handle, ensure_ascii=False, indent=2)
        print(file=handle)  # trailing newline
    os.replace(staging_path, path)


def append_jsonl(path, row):
    """Append *row* to *path* as one JSON line and flush it immediately."""
    with open(path, "a", encoding="utf-8") as sink:
        sink.write(json.dumps(row, ensure_ascii=False))
        sink.write("\n")
        sink.flush()


def default_result_file(hyp_file, metric_model, eval_mode):
    """Derive the default results path for a hypothesis file.

    Legacy mode keeps the historical ``<hyp>.eval-results-<model>`` name with
    the model name used verbatim. Every other mode uses the sanitized model
    name and appends the atomic prompt version and the eval mode.
    """
    if eval_mode != "legacy":
        safe_tag = sanitize_model_name(metric_model)
        return f"{hyp_file}.eval-results-{safe_tag}-{ATOMIC_PROMPT_VERSION}-{eval_mode}"
    return f"{hyp_file}.eval-results-{metric_model}"


def default_rubric_file(ref_file):
    """Return the default rubric cache path that sits next to *ref_file*."""
    rubric_path = f"{ref_file}.{ATOMIC_PROMPT_VERSION}.rubric.json"
    return rubric_path


def load_rubric_file(path, ref_file):
    """Load the rubric cache at *path*, normalising it to the wrapped layout.

    The returned dict has the keys ``prompt_version``, ``source_ref_file`` and
    ``rubrics``:

    * a falsy or missing *path* yields an empty cache;
    * a file that already has a ``rubrics`` key is returned as loaded, with
      the two metadata keys filled in only when absent;
    * any other file content is taken to be the bare rubrics mapping and is
      wrapped.
    """
    metadata = {
        "prompt_version": ATOMIC_PROMPT_VERSION,
        "source_ref_file": ref_file,
    }

    if not (path and os.path.exists(path)):
        return {**metadata, "rubrics": {}}

    with open(path, "r", encoding="utf-8") as handle:
        loaded = json.load(handle)

    if "rubrics" not in loaded:
        return {**metadata, "rubrics": loaded}

    for key, value in metadata.items():
        loaded.setdefault(key, value)
    return loaded


def question_type_guidance(task):
    """Return the grading guidance paragraph for the question type *task*.

    Question types are matched against an ordered table of known names (both
    the title-case and the hyphenated spellings where both exist). Unknown
    types receive a generic "satisfies all required facts" instruction.
    """
    guidance_table = (
        (
            ["Information Absence"],
            "The correct answer is that the information is unavailable, absent, "
            "not yet known, or not supported. A response that gives a concrete "
            "answer instead of abstaining is incorrect.",
        ),
        (
            ["Aggregation", "single-session-user", "single-session-assistant", "multi-session"],
            "Check every requested item, count, list member, and named fact. "
            "Exact counts and required list coverage matter. Extra material "
            "facts that change the answer should be flagged.",
        ),
        (
            ["Aggregation + Temporal"],
            "Check both the aggregated facts and their time/order associations. "
            "An answer can name the right item but still be wrong if the timing, "
            "ordering, before/after relation, or year is wrong.",
        ),
        (
            ["Temporal Reasoning", "temporal-reasoning"],
            "Check the specific time, date, year, sequence, duration, or temporal "
            "relationship asked for. Accept +/-1 only for day/week/month duration "
            "counts, not for years, event identity, or ordering.",
        ),
        (
            ["Knowledge Update", "knowledge-update"],
            "Check the current or most recent state. Historical context is fine "
            "only if the final/current state is clearly correct. Outdated states "
            "presented as current are incorrect.",
        ),
        (
            ["single-session-preference"],
            "Check whether the response recalls and applies the stated preference. "
            "Do not require unnecessary verbosity, but contradictions are incorrect.",
        ),
    )
    for task_names, guidance in guidance_table:
        if task in task_names:
            return guidance
    return "Check whether the model response satisfies all required facts in the reference answer."


def get_anscheck_prompt(task, question, answer, response):
    """Legacy binary yes/no prompt kept for backward compatibility.

    Selects a judge prompt template by question type and fills in the
    question, the reference answer (or rubric/explanation) and the model
    response. Every template asks the judge for a brief explanation followed
    by a final line containing only ``yes`` or ``no``.

    Args:
        task: Question type name; unrecognised values use the generic template.
        question: The question text.
        answer: Reference answer, or for "Information Absence" the explanation
            of why the question is unanswerable.
        response: The model response being judged.

    Returns:
        The fully formatted prompt string.
    """
    # Abstention questions: the judge checks that the model declined to answer.
    if task in ["Information Absence"]:
        template = """You are evaluating whether a model correctly identified that a question is unanswerable given the available personal chat history.

Question: {question}

Explanation of why it is unanswerable: {answer}

Model Response: {response}

Evaluation criteria:
- CORRECT if the model explicitly states that the information is not available, insufficient, or that the question cannot be answered from the provided context.
- INCORRECT if the model fabricates an answer or fails to acknowledge the unanswerable nature of the question.
- The model does not need to use the exact word "unanswerable" -- expressing uncertainty or lack of information is sufficient.

Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
    # Aggregation-style questions: all reference items/counts must be covered.
    elif task in ["Aggregation", "single-session-user", "single-session-assistant", "multi-session"]:
        template = """You are evaluating whether a model correctly answered a question that requires aggregating specific facts from a user's personal chat history.

Question: {question}

Reference Answer: {answer}

Model Response: {response}

Evaluation criteria:
- CORRECT if the response identifies all key items, facts, or counts present in the reference answer, even if phrased differently or with added context.
- INCORRECT if the response:
  - States a wrong count (e.g., says "5" when the answer is "3")
  - Omits one or more key items/facts listed in the reference
  - Lists mostly wrong items even if the count is right
- Partial answers that cover only a subset of required items are INCORRECT.
- Verbose responses are acceptable as long as all reference items are present within them.
- If the response contains correct items but also lists additional plausible-sounding but unverified items beyond the reference, this does NOT make it incorrect -- evaluate only whether the reference items are covered.
- Semantic equivalence counts as correct (e.g., "RSI" = "Repetitive Strain Injury").

Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
    # Aggregation plus temporal ordering: both content and timing must match.
    elif task in ["Aggregation + Temporal"]:
        template = """You are evaluating whether a model correctly answered a question that requires both aggregating facts AND reasoning about their temporal order or time associations from a user's personal chat history.

Question: {question}

Reference Answer: {answer}

Model Response: {response}

Evaluation criteria:
- CORRECT if the response captures both:
  (a) all key events/facts listed in the reference, and
  (b) their correct temporal associations (ordering, time periods, or "before/after" relationships).
- INCORRECT if the response:
  - Omits one or more key events or facts from the reference
  - Gets the temporal ordering or time associations wrong
  - Captures only the content without the temporal aspects, or vice versa
- Responses that describe the correct progression/sequence in different words are acceptable.
- Partial answers covering only some events or ignoring time aspects are INCORRECT.
- Minor wording differences or additional explanatory context are acceptable.

Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
    # Temporal questions: includes the +/-1 tolerance for day/week/month counts.
    elif task in ["Temporal Reasoning", "temporal-reasoning"]:
        template = """You are evaluating whether a model correctly answered a question about temporal relationships in a user's personal chat history.

Question: {question}

Reference Answer: {answer}

Model Response: {response}

Evaluation criteria:
- CORRECT if the response correctly identifies the specific time, date, year, sequence, or temporal relationship asked about.
- INCORRECT if the response states a wrong year, wrong sequence, wrong temporal relationship, or misidentifies which event came first/last.
- Off-by-one tolerance: if the question asks for a count of days, weeks, or months, accept answers that differ by +/-1. This tolerance does NOT apply to years or to identifying specific events/artifacts.
- Responses that correctly identify the fact but with verbose context are acceptable.
- If the response hedges but still states the correct answer, it is correct.

Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
    # Knowledge-update questions: the current/most recent state must be right.
    elif task in ["Knowledge Update", "knowledge-update"]:
        template = """You are evaluating whether a model correctly answered a question about the most recent or current state of something that changed over time in a user's personal chat history.

Question: {question}

Reference Answer: {answer}

Model Response: {response}

Evaluation criteria:
- CORRECT if the response correctly identifies the most recent/current state as described in the reference answer.
- The response may include earlier historical states as context -- this is acceptable as long as the current/final state is correctly identified and clearly stated.
- INCORRECT if the response:
  - States an outdated/superseded state as the current one
  - Omits the current state entirely
  - Correctly describes history but draws the wrong conclusion about what the current state is
- Semantic equivalence counts (e.g., "flexitarian" and "semi-vegetarian diet with occasional meat" are equivalent if contextually clear).

Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
    # Preference questions: the reference is presented to the judge as a rubric.
    elif task == "single-session-preference":
        template = """You are evaluating whether a model correctly answered a personalized question based on a user's stated preferences from their chat history.

Question: {question}

Reference Rubric: {answer}

Model Response: {response}

Evaluation criteria:
- CORRECT if the response recalls and applies the user's personal preferences correctly, even if not covering every point in the rubric.
- INCORRECT if the response ignores, contradicts, or misremembers the user's preferences.

Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
    # Any other question type: generic "contains all key information" check.
    else:
        template = """You are evaluating whether a model's response correctly answers a question based on a user's personal chat history.

Question: {question}

Reference Answer: {answer}

Model Response: {response}

Is the response correct? It is correct if it contains all key information from the reference answer, even if phrased differently.

Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
    # str.format substitutes the three named fields; braces inside the
    # substituted values are inserted as-is and are not re-interpreted.
    return template.format(question=question, answer=answer, response=response)


def build_rubric_prompt(task, question, answer):
    """Build the prompt asking the judge to decompose a reference answer into atoms.

    The prompt embeds the question type, the question, the reference answer
    and the guidance text from ``question_type_guidance(task)``, then asks for
    a JSON-only reply with ``required_atoms`` (each with ``id``,
    ``requirement`` and ``weight``) and ``strict_notes``.

    Args:
        task: Question type name.
        question: The question text.
        answer: The reference answer to decompose.

    Returns:
        The prompt string.
    """
    # Doubled braces in the f-string below are literal braces of the JSON schema.
    return f"""You are creating an atomic grading rubric for an open-ended QA benchmark.

Question type: {task}
Question:
{question}

Reference answer:
{answer}

Question-type guidance:
{question_type_guidance(task)}

Decompose the reference answer into the smallest independently checkable requirements needed to answer the question.

Rules:
- Each atom should be a single required answer unit: an entity, count, date/year, order relation, current-state conclusion, or abstention requirement.
- If an entity and its temporal relation are inseparable for correctness, keep them in the same atom.
- For list/count questions, include an atom for the exact count when the question asks "how many", and atoms for each required listed item when the item identities matter.
- For Information Absence, usually use one atom requiring the response to clearly state that the information is unavailable/insufficient/not discussed, and add a strict note that concrete fabricated answers are wrong.
- Do not include supporting evidence requirements or session IDs unless the question explicitly asks for them.
- Weights should normally be 1.0. Use a higher weight only when one atom is clearly the main answer and other atoms are minor.

Return JSON only with this schema:
{{
  "required_atoms": [
    {{
      "id": "a1",
      "requirement": "short, specific grading requirement",
      "weight": 1.0
    }}
  ],
  "strict_notes": ["short note about exactness, ordering, abstention, or hallucination handling"]
}}"""


def build_atomic_eval_prompt(task, question, answer, response, rubric):
    """Build the prompt asking the judge to grade a response atom by atom.

    The rubric's ``required_atoms`` and ``strict_notes`` are serialised to
    indented JSON and embedded next to the question, reference answer, model
    response and the guidance from ``question_type_guidance(task)``. The judge
    is asked for a JSON-only reply with ``atom_judgments``,
    ``unsupported_or_contradictory``, ``absence_mismatch`` and
    ``overall_rationale``.

    Args:
        task: Question type name.
        question: The question text.
        answer: The reference answer.
        response: The model response being judged.
        rubric: Mapping with a ``required_atoms`` key (mandatory) and an
            optional ``strict_notes`` key.

    Returns:
        The prompt string.

    Raises:
        KeyError: If *rubric* has no ``required_atoms`` key.
    """
    # Only these two rubric fields are shown to the judge; other keys on the
    # rubric dict are left out of the prompt.
    rubric_str = json.dumps(
        {
            "required_atoms": rubric["required_atoms"],
            "strict_notes": rubric.get("strict_notes", []),
        },
        ensure_ascii=False,
        indent=2,
    )
    # Doubled braces in the f-string below are literal braces of the JSON schema.
    return f"""You are an LLM-as-a-judge evaluating one model response against a reference answer.

Question type: {task}
Question:
{question}

Reference answer:
{answer}

Model response:
{response}

Atomic grading rubric:
{rubric_str}

Question-type guidance:
{question_type_guidance(task)}

Judge each atom independently.

Atom labels:
- correct: the response fully satisfies this atom, allowing semantic paraphrase.
- partially_correct: the response gets the main idea but is incomplete or slightly underspecified. Use this sparingly. Do not use it for wrong counts, wrong years, wrong named entities, wrong ordering, or a concrete answer to an Information Absence question.
- missing: the response does not address this atom.
- incorrect: the response contradicts this atom or gives the wrong count/entity/date/order/current state.

Also identify unsupported_or_contradictory material:
- severity "material": extra answer content that changes the final answer, adds extra items to an exact list/count, gives an outdated current state, fabricates a concrete answer for Information Absence, or contradicts any atom.
- severity "minor": harmless context or extra explanation that does not change the answer.

Return JSON only with this schema:
{{
  "atom_judgments": [
    {{
      "id": "a1",
      "label": "correct|partially_correct|missing|incorrect",
      "rationale": "brief reason"
    }}
  ],
  "unsupported_or_contradictory": [
    {{
      "text": "extra or contradictory claim",
      "severity": "minor|material",
      "rationale": "brief reason"
    }}
  ],
  "absence_mismatch": false,
  "overall_rationale": "one or two sentence summary"
}}"""


def _llm_call_with_retry(send_request, source, delay_s):
    """Call *send_request* until it succeeds or fails with a non-retryable error.

    An exception whose status (per ``_retryable_status``) is 429, 500, 503 or
    403 triggers a warning, a sleep of *delay_s* seconds and another attempt,
    with no upper bound on the number of attempts. Any other exception is
    logged and re-raised.

    Args:
        send_request: Zero-argument callable performing one request.
        source: Backend name used in the warning message.
        delay_s: Seconds to sleep between attempts.

    Returns:
        Whatever *send_request* returns.
    """
    while True:
        try:
            return send_request()
        except Exception as e:
            st = _retryable_status(e)
            if st in (429, 500, 503, 403):
                print(f"[WARN] HTTP {st} from {source}; sleeping {delay_s}s then retrying...", flush=True)
                time.sleep(delay_s)
                continue
            print("One exception captured", repr(e), flush=True)
            raise


def llm_call(
    deployment_name: str,
    api_version: str,
    _prompt: str,
    debug: bool = False,
    vllm: bool = False,
    tritonai: bool = False,
    nvidia: bool = False,
):
    """Send *_prompt* to the selected LLM backend and return its completion.

    Backend precedence (first match wins):

    1. ``nvidia``: OpenAI-compatible NVIDIA endpoint (``NV_API_KEY``).
    2. ``tritonai``: OpenAI-compatible proxy at ``TRITONAI_BASE_URL``
       (``TRITONAI_API_KEY``).
    3. *deployment_name* starting with ``"claude-"``: Anthropic Messages API
       (``ANTHROPIC_API_KEY``), capped at 1024 output tokens.
    4. ``vllm``: OpenAI-compatible server at ``VLLM_BASE_URL``.
    5. ``debug``: OpenAI API (``OPENAI_API_KEY``).
    6. otherwise Azure OpenAI with the module-level ``endpoint`` and
       ``credential``.

    Retryable HTTP errors are retried indefinitely (60 s pause for backends
    1-3, 120 s for 4-6); other errors are logged and re-raised.

    Args:
        deployment_name: Model or deployment identifier passed to the backend.
        api_version: API version; only used by the Azure backend.
        _prompt: Prompt text, sent as a single message.
        debug: Use the OpenAI API instead of Azure.
        vllm: Use a vLLM server.
        tritonai: Use the LiteLLM proxy.
        nvidia: Use the NVIDIA inference API.

    Returns:
        An object exposing ``choices[0].message.content``: the SDK completion
        for OpenAI-style backends, or a small adapter with the same attribute
        path for Anthropic.

    Raises:
        RuntimeError: If the Azure backend is selected but the optional Azure
            imports failed at module load (previously this surfaced as an
            opaque ``TypeError`` from calling ``None``).
    """
    user_messages = [{"role": "user", "content": _prompt}]

    if nvidia:
        client = OpenAI(
            api_key=os.getenv("NV_API_KEY"),
            base_url="https://inference-api.nvidia.com",
        )
        return _llm_call_with_retry(
            lambda: client.chat.completions.create(model=deployment_name, messages=user_messages),
            "NVIDIA API",
            60,
        )

    if tritonai:
        client = OpenAI(
            api_key=os.getenv("TRITONAI_API_KEY"),
            base_url=TRITONAI_BASE_URL,
        )
        return _llm_call_with_retry(
            lambda: client.chat.completions.create(model=deployment_name, messages=user_messages),
            "LiteLLM proxy",
            60,
        )

    if deployment_name.startswith("claude-"):
        import anthropic

        # Minimal adapter so callers can read ``choices[0].message.content``
        # exactly as they do for OpenAI-style completions.
        class _Choice:
            class _Msg:
                def __init__(self, text):
                    self.content = text

            def __init__(self, text):
                self.message = self._Msg(text)

        class _Completion:
            def __init__(self, text):
                self.choices = [_Choice(text)]

        client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

        def _ask_claude():
            msg = client.messages.create(
                model=deployment_name,
                max_tokens=1024,
                messages=user_messages,
            )
            # Kept inside the retried call, as before, so a failure while
            # unpacking the reply goes through the same error handling.
            return _Completion(msg.content[0].text)

        return _llm_call_with_retry(_ask_claude, "Anthropic", 60)

    if vllm:
        client = OpenAI(
            base_url=os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
            api_key=os.getenv("VLLM_API_KEY", "EMPTY"),
        )
    elif debug:
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    else:
        if AzureOpenAI is None:
            raise RuntimeError(
                "Azure OpenAI backend selected, but AzureOpenAI / azure-identity could not be "
                "imported. Install the Azure dependencies or choose another backend."
            )
        client = AzureOpenAI(
            azure_endpoint=endpoint,
            azure_ad_token_provider=credential,
            api_version=api_version,
        )

    # NOTE(review): these backends send the prompt with the "system" role,
    # unlike the "user" role used above. Kept as-is to preserve behaviour.
    kwargs = {
        "model": deployment_name,
        "messages": [{"role": "system", "content": _prompt}],
    }
    return _llm_call_with_retry(lambda: client.chat.completions.create(**kwargs), "LLM", 120)


def call_json_llm(prompt, deployment_name, api_version, args, max_retries=3):
    """Query the judge and parse its reply as JSON, retrying on parse failure.

    Each attempt performs a fresh ``llm_call`` and feeds the reply to
    ``parse_json_object``. A reply that cannot be parsed is retried after a
    2-second pause, up to *max_retries* attempts in total.

    Args:
        prompt: Prompt text sent to the judge.
        deployment_name: Model or deployment identifier.
        api_version: API version forwarded to ``llm_call``.
        args: Namespace providing ``debug``, ``vllm``, ``tritonai`` and
            ``nvidia`` backend flags.
        max_retries: Maximum number of attempts.

    Returns:
        A ``(parsed, raw_text)`` tuple.

    Raises:
        ValueError: If no attempt produced parseable JSON; chained to the last
            parse error.
    """
    last_error = None
    for attempt in range(max_retries):
        completion = llm_call(
            deployment_name,
            api_version,
            prompt,
            debug=args.debug,
            vllm=args.vllm,
            tritonai=args.tritonai,
            nvidia=args.nvidia,
        )
        # The message content can be None. Treat that as an empty reply so it
        # is handled as a parse failure (and retried) instead of raising
        # AttributeError outside the retry logic.
        content = (completion.choices[0].message.content or "").strip()
        try:
            return parse_json_object(content), content
        except Exception as e:
            last_error = e
            if attempt < max_retries - 1:
                print(f"[WARN] Failed to parse judge JSON; retrying ({attempt + 1}/{max_retries})", flush=True)
                time.sleep(2)
    raise ValueError(f"Failed to parse JSON response from judge: {last_error}") from last_error


def fallback_rubric(qid, task, question, answer):
    """Build a one-atom rubric that grades against the whole reference answer.

    Used when a generated rubric is unusable; the single atom ``a1`` has
    weight 1.0 and quotes *answer* in its requirement.
    """
    sole_atom = {
        "id": "a1",
        "requirement": f"Response must correctly answer the question according to the reference answer: {answer}",
        "weight": 1.0,
    }
    fallback_note = "Fallback single-atom rubric produced because the generated rubric was invalid."

    rubric = {
        "question_id": qid,
        "question_type": task,
        "question": question,
        "reference_answer": answer,
    }
    rubric["required_atoms"] = [sole_atom]
    rubric["strict_notes"] = [fallback_note]
    rubric["prompt_version"] = ATOMIC_PROMPT_VERSION
    return rubric


def normalize_rubric(qid, task, question, answer, parsed):
    """Validate a judge-generated rubric and return it in canonical form.

    Atoms that are not dicts or have an empty ``requirement`` are dropped. If
    no usable atom remains (including when *parsed* is not a dict or its
    ``required_atoms`` is not a list), the single-atom ``fallback_rubric`` is
    returned instead.

    Hardening compared with a plain copy of the judge output:

    * atom ids are made unique, and a null/blank id gets a positional
      ``a<N>`` id, because judgments are later matched to atoms by id;
    * weights that are non-numeric, non-positive, NaN or infinite are reset
      to 1.0 so they cannot produce a NaN score;
    * a null ``strict_notes`` becomes an empty list rather than ``["None"]``.

    Args:
        qid: Question identifier stored on the rubric.
        task: Question type name.
        question: The question text.
        answer: The reference answer.
        parsed: Decoded JSON returned by the rubric-building prompt.

    Returns:
        The rubric dict, tagged with the current atomic prompt version.
    """
    atoms = parsed.get("required_atoms", []) if isinstance(parsed, dict) else []
    if not isinstance(atoms, list):
        atoms = []

    norm_atoms = []
    seen_ids = set()
    for idx, atom in enumerate(atoms, start=1):
        if not isinstance(atom, dict):
            continue
        requirement = str(atom.get("requirement", "")).strip()
        if not requirement:
            continue

        raw_id = atom.get("id")
        atom_id = ("" if raw_id is None else str(raw_id).strip()) or f"a{idx}"
        # Resolve collisions by moving to the next free positional id.
        next_suffix = idx
        while atom_id in seen_ids:
            atom_id = f"a{next_suffix}"
            next_suffix += 1
        seen_ids.add(atom_id)

        try:
            weight = float(atom.get("weight", 1.0))
        except (TypeError, ValueError, OverflowError):
            weight = 1.0
        # Written as a chained comparison so NaN (which fails every
        # comparison) is rejected along with <= 0 and +inf.
        if not (0 < weight < float("inf")):
            weight = 1.0

        norm_atoms.append({"id": atom_id, "requirement": requirement, "weight": weight})

    if not norm_atoms:
        return fallback_rubric(qid, task, question, answer)

    strict_notes = parsed.get("strict_notes", [])
    if strict_notes is None:
        strict_notes = []
    elif not isinstance(strict_notes, list):
        strict_notes = [str(strict_notes)]

    return {
        "question_id": qid,
        "question_type": task,
        "question": question,
        "reference_answer": answer,
        "required_atoms": norm_atoms,
        "strict_notes": [str(x) for x in strict_notes],
        "prompt_version": ATOMIC_PROMPT_VERSION,
    }


def get_or_build_rubric(qdata, rubric_data, rubric_file, deployment_name, api_version, args):
    """Return the rubric for one question, generating and caching it if needed.

    A cached rubric is reused when it carries the current prompt version, has
    at least one atom, and ``args.force_rebuild_rubric`` is not set. Otherwise
    the judge is asked for a new rubric; if that fails for any reason a
    single-atom fallback rubric is used. The resulting rubric is stored in
    ``rubric_data["rubrics"]`` and the whole cache is written to *rubric_file*.
    """
    qid = qdata["question_id"]
    cached = rubric_data["rubrics"].get(qid)
    cache_is_usable = (
        cached
        and cached.get("prompt_version") == ATOMIC_PROMPT_VERSION
        and cached.get("required_atoms")
        and not args.force_rebuild_rubric
    )
    if cache_is_usable:
        return cached

    task = qdata["question_type"]
    rubric_prompt = build_rubric_prompt(task, qdata["question"], qdata["answer"])
    try:
        parsed, raw = call_json_llm(rubric_prompt, deployment_name, api_version, args)
        fresh = normalize_rubric(qid, task, qdata["question"], qdata["answer"], parsed)
        fresh["rubric_raw_response"] = raw
    except Exception as e:
        print(f"[WARN] Falling back to single-atom rubric for {qid}: {e}", flush=True)
        fresh = fallback_rubric(qid, task, qdata["question"], qdata["answer"])

    # Persist after every new rubric so an interrupted run keeps its progress.
    rubric_data["rubrics"][qid] = fresh
    write_json(rubric_file, rubric_data)
    return fresh


def compute_atomic_scores(rubric, parsed):
    """Turn the judge's per-atom judgments into strict and partial scores.

    Judgments are matched to rubric atoms by id (compared as stripped
    strings). An atom with no judgment is labelled ``missing``; a judgment
    with an unknown label is treated as ``incorrect``.

    Scoring rules:

    * ``strict_label`` is True only when there is at least one atom, every
      atom is ``correct``, no extra is marked ``material`` and
      ``absence_mismatch`` is false.
    * ``partial_score`` is the weight-averaged atom score, forced to 0.0 on an
      absence mismatch and capped at 0.8 when material extras exist.

    Each atom gets its own copy of the matched judgment before
    ``requirement`` and ``weight`` are attached, so rubric atoms that share an
    id no longer alias (and overwrite) one another in the output.

    Args:
        rubric: Rubric dict with a ``required_atoms`` list.
        parsed: Decoded judge reply; anything that is not a dict is treated as
            an empty reply.

    Returns:
        Dict with ``strict_label``, ``partial_score`` (rounded to 4 places),
        ``atom_judgments``, ``unsupported_or_contradictory``,
        ``absence_mismatch`` and ``overall_rationale``.
    """
    if not isinstance(parsed, dict):
        parsed = {}
    atoms = rubric.get("required_atoms", [])

    judgments_by_id = {}
    raw_judgments = parsed.get("atom_judgments", [])
    if isinstance(raw_judgments, list):
        for judgment in raw_judgments:
            if not isinstance(judgment, dict):
                continue
            atom_id = str(judgment.get("id", "")).strip()
            label = str(judgment.get("label", "")).strip()
            if label not in ATOM_SCORES:
                label = "incorrect"
            judgments_by_id[atom_id] = {
                "id": atom_id,
                "label": label,
                "score": ATOM_SCORES[label],
                "rationale": str(judgment.get("rationale", "")),
            }

    norm_judgments = []
    weighted_score = 0.0
    total_weight = 0.0
    for atom in atoms:
        atom_id = atom["id"]
        weight = float(atom.get("weight", 1.0))
        matched = judgments_by_id.get(str(atom_id).strip())
        if matched is None:
            judgment = {"id": atom_id, "label": "missing", "score": 0.0, "rationale": "No judgment returned."}
        else:
            # Copy so the annotations below never leak into another atom's entry.
            judgment = dict(matched)
        judgment["requirement"] = atom["requirement"]
        judgment["weight"] = weight
        norm_judgments.append(judgment)
        weighted_score += judgment["score"] * weight
        total_weight += weight

    extras = parsed.get("unsupported_or_contradictory", [])
    if not isinstance(extras, list):
        extras = []
    material_extras = [
        x for x in extras
        if isinstance(x, dict) and str(x.get("severity", "")).strip() == "material"
    ]
    absence_mismatch = bool(parsed.get("absence_mismatch", False))

    strict_label = (
        bool(norm_judgments)
        and all(j["label"] == "correct" for j in norm_judgments)
        and not material_extras
        and not absence_mismatch
    )
    partial_score = weighted_score / total_weight if total_weight > 0 else 0.0
    if absence_mismatch:
        partial_score = 0.0
    elif material_extras and partial_score > 0.8:
        partial_score = 0.8

    return {
        "strict_label": strict_label,
        "partial_score": round(partial_score, 4),
        "atom_judgments": norm_judgments,
        "unsupported_or_contradictory": extras,
        "absence_mismatch": absence_mismatch,
        "overall_rationale": str(parsed.get("overall_rationale", "")),
    }


def judge_atomic(qdata, hypothesis, rubric, deployment_name, api_version, args):
    """Judge one hypothesis against its rubric with the atomic JSON judge.

    The atomic evaluation prompt is built from the question data, the
    hypothesis and the rubric, sent through ``call_json_llm``, and the parsed
    reply is scored by ``compute_atomic_scores``.

    Args:
        qdata: Reference entry; ``question_type``, ``question`` and
            ``answer`` are read from it.
        hypothesis: The answer text being judged.
        rubric: Rubric mapping; its ``required_atoms`` entry is copied into
            the result.
        deployment_name: Forwarded unchanged to ``call_json_llm``.
        api_version: Forwarded unchanged to ``call_json_llm``.
        args: Parsed CLI namespace; ``eval_model_name`` and ``eval_mode`` are
            recorded in the result and the namespace is forwarded to
            ``call_json_llm``.

    Returns:
        A dict holding the judge metadata, the strict/partial scores, the
        per-atom judgments and the raw judge response.
    """
    judge_prompt = build_atomic_eval_prompt(
        qdata["question_type"],
        qdata["question"],
        qdata["answer"],
        hypothesis,
        rubric,
    )
    parsed_reply, raw_reply = call_json_llm(judge_prompt, deployment_name, api_version, args)
    scored = compute_atomic_scores(rubric, parsed_reply)

    # Keys are inserted in the order they are serialised into the result file.
    result = {
        "model": args.eval_model_name,
        "prompt_version": ATOMIC_PROMPT_VERSION,
        "eval_mode": args.eval_mode,
    }
    for key in ("strict_label", "partial_score"):
        result[key] = scored[key]
    result["required_atoms"] = rubric["required_atoms"]
    for key in (
        "atom_judgments",
        "unsupported_or_contradictory",
        "absence_mismatch",
        "overall_rationale",
    ):
        result[key] = scored[key]
    result["raw_response"] = raw_reply
    return result


def should_skip_existing(existing_row, eval_mode):
    """Tell whether a previously saved row can be reused instead of re-judged.

    Args:
        existing_row: Row loaded from an earlier result file.
        eval_mode: The evaluation mode requested for this run.

    Returns:
        In ``"legacy"`` mode, True when the row has an ``autoeval_label``
        entry. In every other mode, True only when the row has a non-empty
        ``autoeval_atomic`` entry whose ``prompt_version`` equals
        ``ATOMIC_PROMPT_VERSION``.
    """
    if eval_mode == "legacy":
        return "autoeval_label" in existing_row

    cached = existing_row.get("autoeval_atomic")
    if not cached:
        # Nothing stored (or an empty record): the row must be judged again.
        return False
    return bool(cached.get("prompt_version") == ATOMIC_PROMPT_VERSION)


def safe_mean(values):
    """Return the arithmetic mean of ``values`` as a float, or NaN when it is empty/falsy."""
    return float(np.mean(values)) if values else float("nan")


def print_legacy_summary(logs, qtype2acc):
    """Print overall and per-question-type accuracy for the legacy yes/no judge.

    Args:
        logs: Result rows; only rows carrying an ``autoeval_label`` entry
            contribute to the overall accuracy.
        qtype2acc: Mapping from question type to a list of 0/1 outcomes.
    """
    judged_rows = (row for row in logs if "autoeval_label" in row)
    hits = [int(bool(row["autoeval_label"]["label"])) for row in judged_rows]
    print("Accuracy:", round(safe_mean(hits), 4))

    # One indented line per question type, in sorted order.
    for qtype, outcomes in sorted(qtype2acc.items()):
        type_accuracy = round(safe_mean(outcomes), 4)
        print("\t{}: {} ({})".format(qtype, type_accuracy, len(outcomes)))


def print_atomic_summary(logs, qtype2strict, qtype2partial, eval_mode):
    """Print the strict and/or partial summaries for the atomic judge.

    Args:
        logs: Result rows; only rows carrying an ``autoeval_atomic`` entry
            contribute to the overall figures.
        qtype2strict: Mapping from question type to a list of 0/1 strict outcomes.
        qtype2partial: Mapping from question type to a list of partial scores.
        eval_mode: ``"strict"`` and ``"partial"`` print their own section;
            ``"both"`` prints the strict section followed by the partial one.
    """
    atomic_results = [row["autoeval_atomic"] for row in logs if "autoeval_atomic" in row]
    strict_hits = [int(bool(result["strict_label"])) for result in atomic_results]
    partial_scores = [float(result["partial_score"]) for result in atomic_results]

    # (modes that enable the section, heading, overall values, per-type values)
    sections = (
        (("strict", "both"), "Strict Accuracy:", strict_hits, qtype2strict),
        (("partial", "both"), "Partial Score:", partial_scores, qtype2partial),
    )
    for enabled_modes, heading, overall, per_type in sections:
        if eval_mode not in enabled_modes:
            continue
        print(heading, round(safe_mean(overall), 4))
        for qtype, scores in sorted(per_type.items()):
            print("\t{}: {} ({})".format(qtype, round(safe_mean(scores), 4), len(scores)))


def main():
    """Command-line entry point: judge hypotheses against references and report scores.

    Workflow:
      1. Parse the CLI arguments and load the reference entries.
      2. Unless running in ``legacy`` mode, load the rubric file. With
         ``--build_rubric_only`` a rubric is built for every reference and the
         function returns without judging anything.
      3. Walk the hypotheses. Rows already present in the result file that
         ``should_skip_existing`` accepts are reused; every other hypothesis is
         judged (legacy yes/no judge or atomic JSON judge) and appended to the
         result file immediately.
      4. Print the summary for the selected evaluation mode.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--hyp_file", type=str, default=None)
    parser.add_argument("--ref_file", type=str, required=True)
    parser.add_argument("--eval_model_name", type=str, required=True)
    parser.add_argument(
        "--eval_mode",
        type=str,
        default="both",
        choices=["legacy", "strict", "partial", "both"],
        help="legacy uses the old yes/no judge; strict/partial/both use atomic JSON judging.",
    )
    parser.add_argument("--rubric_file", type=str, default=None)
    parser.add_argument("--build_rubric_only", action="store_true", default=False)
    parser.add_argument("--force_rebuild_rubric", action="store_true", default=False)
    parser.add_argument("--result_file", type=str, default=None)
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--vllm", action="store_true", default=False)
    parser.add_argument("--tritonai", action="store_true", default=False,
                        help="Use OpenAI-compatible LiteLLM proxy (set TRITONAI_API_KEY env var)")
    parser.add_argument("--nvidia", action="store_true", default=False,
                        help="Use NVIDIA inference API (set NV_API_KEY env var)")
    parser.add_argument("--verbose", action=argparse.BooleanOptionalAction, default=True)
    args = parser.parse_args()

    if not args.build_rubric_only and not args.hyp_file:
        parser.error("--hyp_file is required unless --build_rubric_only is set")

    metric_model = args.eval_model_name
    deployment_name, api_version = model_zoo[metric_model]
    references = read_json_or_jsonl(args.ref_file)
    qid2qdata = {entry["question_id"]: entry for entry in references}
    qid2qtype = {entry["question_id"]: entry["question_type"] for entry in references}
    qtypes = set(qid2qtype.values())

    # Rubrics are only needed by the atomic judge (or when building them explicitly).
    rubric_data = None
    rubric_file = args.rubric_file or default_rubric_file(args.ref_file)
    if args.eval_mode != "legacy" or args.build_rubric_only:
        rubric_data = load_rubric_file(rubric_file, args.ref_file)

    if args.build_rubric_only:
        for entry in tqdm(references, desc="building rubrics"):
            get_or_build_rubric(entry, rubric_data, rubric_file, deployment_name, api_version, args)
        print(f"Saved rubric file to {rubric_file}")
        return

    result_file = args.result_file or default_result_file(args.hyp_file, metric_model, args.eval_mode)
    existing = read_existing_jsonl(result_file)
    hypotheses = read_json_or_jsonl(args.hyp_file)

    # Per-question-type accumulators feeding the final summaries.
    qtype2acc = {t: [] for t in qtypes}
    qtype2strict = {t: [] for t in qtypes}
    qtype2partial = {t: [] for t in qtypes}
    logs = []

    for entry in tqdm(hypotheses):
        qid = entry.get("question_id")
        if qid not in qid2qtype:
            if qid is not None:
                print(f"Warning: skipping {qid} as it is not in reference data.")
            continue

        # Resume support: reuse a stored judgment instead of calling the judge again.
        if qid in existing and should_skip_existing(existing[qid], args.eval_mode):
            existing_row = existing[qid]
            logs.append(existing_row)
            qtype = qid2qtype[qid]
            if args.eval_mode == "legacy":
                label = existing_row["autoeval_label"]["label"]
                qtype2acc[qtype].append(1 if label else 0)
            else:
                atomic = existing_row["autoeval_atomic"]
                qtype2strict[qtype].append(1 if atomic["strict_label"] else 0)
                qtype2partial[qtype].append(float(atomic["partial_score"]))
            continue

        qdata = qid2qdata[qid]
        qtype = qdata["question_type"]
        hyp = entry["hypothesis"]

        if args.eval_mode == "legacy":
            prompt = get_anscheck_prompt(qtype, qdata["question"], qdata["answer"], hyp)
            completion = llm_call(
                deployment_name,
                api_version,
                prompt,
                debug=args.debug,
                vllm=args.vllm,
                tritonai=args.tritonai,
                nvidia=args.nvidia,
            )
            # The message content may be None; treat that as an empty
            # response (scored as "no") instead of crashing on .strip().
            eval_response = (completion.choices[0].message.content or "").strip()
            # The verdict is read from the last non-blank line of the reply.
            last_line = next(
                (line.strip().lower() for line in reversed(eval_response.splitlines()) if line.strip()),
                "",
            )
            label = last_line.startswith("yes")
            row = dict(entry)
            row["autoeval_label"] = {
                "model": metric_model,
                "prompt_version": LEGACY_PROMPT_VERSION,
                "label": label,
                "raw_response": eval_response,
            }
            logs.append(row)
            qtype2acc[qtype].append(1 if label else 0)
            if args.verbose:
                print(json.dumps({
                    "question": qdata["question"],
                    "answer": qdata["answer"],
                    "hypothesis": hyp,
                    "autoeval_label": label,
                }, indent=4), flush=True)
            append_jsonl(result_file, row)
            continue

        rubric = get_or_build_rubric(qdata, rubric_data, rubric_file, deployment_name, api_version, args)
        try:
            atomic_eval = judge_atomic(qdata, hyp, rubric, deployment_name, api_version, args)
        except ValueError as judge_err:
            # A ValueError from the judge is recorded as a zero-score row so
            # the run keeps going. The record mirrors the keys of a
            # successful judge_atomic result, including the same "model"
            # identifier and a "raw_response" entry.
            print(f"[WARN] judge_atomic failed for {qdata['question_id']}, writing zero score: {judge_err}", flush=True)
            atoms = rubric.get("required_atoms", [])
            atomic_eval = {
                "model": metric_model,
                "prompt_version": ATOMIC_PROMPT_VERSION,
                "eval_mode": args.eval_mode,
                "strict_label": False,
                "partial_score": 0.0,
                "required_atoms": atoms,
                "atom_judgments": [
                    {
                        "id": a["id"],
                        "label": "error",
                        "score": 0.0,
                        "rationale": "judge parse error",
                        "requirement": a.get("requirement", ""),
                        "weight": a.get("weight", 1.0),
                    }
                    for a in atoms
                ],
                "unsupported_or_contradictory": [],
                "absence_mismatch": False,
                "overall_rationale": f"Skipped: judge JSON parse error ({judge_err})",
                "raw_response": "",
            }
        row = dict(entry)
        row["autoeval_atomic"] = atomic_eval
        logs.append(row)
        qtype2strict[qtype].append(1 if atomic_eval["strict_label"] else 0)
        qtype2partial[qtype].append(float(atomic_eval["partial_score"]))
        if args.verbose:
            print(json.dumps({
                "question": qdata["question"],
                "answer": qdata["answer"],
                "hypothesis": hyp,
                "strict_label": atomic_eval["strict_label"],
                "partial_score": atomic_eval["partial_score"],
                "atom_judgments": atomic_eval["atom_judgments"],
            }, indent=4), flush=True)
        append_jsonl(result_file, row)

    if args.eval_mode == "legacy":
        print_legacy_summary(logs, qtype2acc)
    else:
        print_atomic_summary(logs, qtype2strict, qtype2partial, args.eval_mode)
        print(f"Rubric file: {rubric_file}")
    print("Saved to", result_file)


# Run the evaluation CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()