"""Curate LLMail submissions into local easy/medium/hard scenario files.

Command-line entry point: rows are loaded per phase (from a local JSONL file
when one is supplied, otherwise through ``load_rows_from_hf``), curated with
``env.llmail_curator.curate_rows``, merged per difficulty and handed to
``write_curated_scenarios``.
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

# ROOT is the parent of this file's directory; putting it on sys.path lets the
# ``env`` package import below resolve when the file is executed as a script.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

from env.llmail_curator import (  # noqa: E402 - must follow the sys.path tweak
    DIFFICULTIES,
    curate_rows,
    load_llmail_contexts,
    load_rows_from_hf,
    load_rows_from_jsonl,
    write_curated_scenarios,
)

DEFAULT_OUTPUT_DIR = ROOT / "env" / "data" / "scenarios" / "curated"


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the curation script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which matches the previous behaviour; an
            explicit list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Curate LLMail submissions into local easy/medium/hard scenario files for the IDPI mail benchmark."
    )
    parser.add_argument(
        "--phase",
        action="append",
        choices=["Phase1", "Phase2"],
        help="Repeat to choose one or both Hugging Face splits. Defaults to both.",
    )
    parser.add_argument("--input-jsonl-phase1", type=Path, help="Optional local Phase1 JSONL file.")
    parser.add_argument("--input-jsonl-phase2", type=Path, help="Optional local Phase2 JSONL file.")
    parser.add_argument("--scenarios-json", type=Path, help="Optional local copy of LLMail scenarios.json.")
    parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--prefix", default="")
    parser.add_argument("--max-per-difficulty", type=int, default=50)
    parser.add_argument(
        "--sample-limit",
        type=int,
        default=None,
        help="Optional cap per split for quick dry runs.",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--overwrite", action="store_true")
    return parser.parse_args(argv)


def rows_for_phase(phase: str, args: argparse.Namespace):
    """Load the rows for one phase, preferring a local JSONL file when given.

    Args:
        phase: Split name, ``"Phase1"`` or ``"Phase2"``.
        args: Parsed options; ``input_jsonl_phase1``, ``input_jsonl_phase2``
            and ``sample_limit`` are read.

    Returns:
        Whatever ``load_rows_from_jsonl`` or ``load_rows_from_hf`` returns.
    """
    if phase == "Phase1" and args.input_jsonl_phase1:
        return load_rows_from_jsonl(args.input_jsonl_phase1, sample_limit=args.sample_limit)
    if phase == "Phase2" and args.input_jsonl_phase2:
        return load_rows_from_jsonl(args.input_jsonl_phase2, sample_limit=args.sample_limit)
    # No local file for this phase: fall back to the remote loader.
    return load_rows_from_hf(phase, sample_limit=args.sample_limit)


def main(argv: list[str] | None = None) -> None:
    """Run the curation pipeline and print one summary line per written file.

    Args:
        argv: Optional argument list forwarded to ``parse_args``; ``None``
            uses the process command line.
    """
    args = parse_args(argv)

    # ``--phase`` uses action="append", so the same split can be passed twice;
    # de-duplicate (order preserved) so its rows are not curated and merged twice.
    phases = list(dict.fromkeys(args.phase)) if args.phase else ["Phase1", "Phase2"]

    contexts = load_llmail_contexts(local_path=args.scenarios_json)
    merged = {difficulty: [] for difficulty in DIFFICULTIES}

    # NOTE(review): max_per_difficulty is passed to each per-phase call, so the
    # merged lists may hold entries from every phase - confirm this is intended.
    for phase in phases:
        curated = curate_rows(
            rows_for_phase(phase, args),
            phase=phase,
            contexts=contexts,
            max_per_difficulty=args.max_per_difficulty,
            seed=args.seed,
        )
        for difficulty in DIFFICULTIES:
            merged[difficulty].extend(curated[difficulty])

    written = write_curated_scenarios(
        merged,
        output_dir=args.output_dir,
        prefix=args.prefix,
        overwrite=args.overwrite,
    )

    for path in written:
        # The count assumes one scenario per line, as the line-based tally
        # implies; blank lines are skipped so they do not inflate the number.
        with path.open("r", encoding="utf-8-sig") as handle:
            count = sum(1 for line in handle if line.strip())
        print(f"wrote {count:>3} scenarios -> {path}")


if __name__ == "__main__":
    main()