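"""Curate LLMail submissions into local scenario files for the IDPI mail benchmark.

Rows are loaded per phase (from a local JSONL file when one is supplied,
otherwise from the matching Hugging Face split), curated into difficulty
buckets, merged across phases, and written to the output directory.
"""

from __future__ import annotations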

import argparse
import sys
from pathlib import Path

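# Put the directory one level above this script's folder on sys.path so the
# `env` package resolves when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))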

from env.llmail_curator import (
    DIFFICULTIES,
    curate_rows,
    load_llmail_contexts,
    load_rows_from_hf,
    load_rows_from_jsonl,
    write_curated_scenarios,
)

DEFAULT_OUTPUT_DIR = ROOT / "env" / "data" / "scenarios" / "curated"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Curate LLMail submissions into local easy/medium/hard scenario files for the IDPI mail benchmark."
    )
    parser.add_argument("--phase", action="append", choices=["Phase1", "Phase2"], help="Repeat to choose one or both Hugging Face splits. Defaults to both.")
    parser.add_argument("--input-jsonl-phase1", type=Path, help="Optional local Phase1 JSONL file.")
    parser.add_argument("--input-jsonl-phase2", type=Path, help="Optional local Phase2 JSONL file.")
    parser.add_argument("--scenarios-json", type=Path, help="Optional local copy of LLMail scenarios.json.")
    parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--prefix", default="")
    parser.add_argument("--max-per-difficulty", type=int, default=50)
    parser.add_argument("--sample-limit", type=int, default=None, help="Optional cap per split for quick dry runs.")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--overwrite", action="store_true")
    return parser.parse_args()


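def rows_for_phase(phase: str, args: argparse.Namespace):
    """Return the rows for one phase.

    A local JSONL file takes precedence when one was supplied for the phase;
    otherwise the rows come from the matching Hugging Face split.
    ``--sample-limit`` is forwarded to whichever loader is used.
    """
    if phase == "Phase1" and args.input_jsonl_phase1:
        return load_rows_from_jsonl(args.input_jsonl_phase1, sample_limit=args.sample_limit)
    if phase == "Phase2" and args.input_jsonl_phase2:
        return load_rows_from_jsonl(args.input_jsonl_phase2, sample_limit=args.sample_limit)
    return load_rows_from_hf(phase, sample_limit=args.sample_limit)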


def main() -> None:
    args = parse_args()
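    # action="append" keeps repeats, so drop duplicate phases (order preserved)
    # to avoid curating the same split twice.
    phases = list(dict.fromkeys(args.phase or ["Phase1", "Phase2"]))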
    contexts = load_llmail_contexts(local_path=args.scenarios_json)

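    # max_per_difficulty is passed to each curate_rows call, so a merged bucket
    # can hold up to that many scenarios per phase.
    merged = {difficulty: [] for difficulty in DIFFICULTIES}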
    for phase in phases:
        curated = curate_rows(
            rows_for_phase(phase, args),
            phase=phase,
            contexts=contexts,
            max_per_difficulty=args.max_per_difficulty,
            seed=args.seed,
        )
        for difficulty in DIFFICULTIES:
            merged[difficulty].extend(curated[difficulty])

    written = write_curated_scenarios(
        merged,
        output_dir=args.output_dir,
        prefix=args.prefix,
        overwrite=args.overwrite,
    )

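    for path in written:
        # Count one scenario per non-blank line, assuming the written files
        # are line-delimited; blank lines are not scenarios.
        with path.open("r", encoding="utf-8-sig") as handle:
            count = sum(1 for line in handle if line.strip())
        print(f"wrote {count:>3} scenarios -> {path}")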


if __name__ == "__main__":
    main()