#!/usr/bin/env python3
"""Create deterministic dataset/retrieval-cache shards by question id."""

from __future__ import annotations

import argparse
import json
from contextlib import ExitStack
from pathlib import Path
from typing import Dict


def _require_positive_shards(num_shards: int) -> None:
    """Raise ``ValueError`` unless *num_shards* is at least 1."""
    if num_shards < 1:
        raise ValueError(f"num_shards must be a positive integer, got {num_shards}")


def split_json_dataset(dataset_path: Path, out_dir: Path, num_shards: int) -> Dict[str, int]:
    """Split a JSON dataset into ``num_shards`` round-robin shard files.

    The dataset file is read as a JSON sequence of entries, each carrying a
    ``"question_id"`` key. Entry ``i`` is assigned to shard ``i % num_shards``
    the first time its question id is seen; any later entry with the same
    question id is placed in that same shard, so the returned mapping always
    agrees with where the entries were actually written.

    One file per shard is written to ``out_dir / "dataset"`` as
    ``shard_NN.json`` (indented JSON followed by a newline). Shards that
    receive no entries are still written, as empty lists.

    Args:
        dataset_path: Path of the JSON dataset to split.
        out_dir: Root output directory; the ``dataset`` subdirectory is
            created if needed.
        num_shards: Number of shards to produce; must be >= 1.

    Returns:
        Mapping from question id to the shard index holding its entries.

    Raises:
        ValueError: If ``num_shards`` is smaller than 1.
        KeyError: If an entry has no ``"question_id"`` key.
    """
    _require_positive_shards(num_shards)

    # Use a context manager so the input handle is closed deterministically.
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    qid_to_shard: Dict[str, int] = {}
    shards = [[] for _ in range(num_shards)]
    dataset_dir = out_dir / "dataset"
    dataset_dir.mkdir(parents=True, exist_ok=True)

    for idx, entry in enumerate(data):
        # setdefault pins a repeated question id to its first shard; otherwise
        # the mapping (used to route cache rows) could disagree with the shard
        # an earlier duplicate entry was written to.
        shard = qid_to_shard.setdefault(entry["question_id"], idx % num_shards)
        shards[shard].append(entry)

    for shard, rows in enumerate(shards):
        path = dataset_dir / f"shard_{shard:02d}.json"
        with open(path, "w", encoding="utf-8") as f:
            json.dump(rows, f, ensure_ascii=False, indent=2)
            f.write("\n")

    return qid_to_shard


def split_jsonl_cache(cache_path: Path, out_dir: Path, qid_to_shard: Dict[str, int], num_shards: int) -> None:
    """Route the rows of a JSONL cache file into per-shard JSONL files.

    Each non-blank line of ``cache_path`` is parsed as a JSON object and
    looked up by its ``"question_id"`` in ``qid_to_shard``. Rows whose
    question id is in the mapping are re-serialized (one object per line)
    into ``out_dir / shard_NN.jsonl``; rows with an unknown question id are
    dropped. Every shard file is created even if it stays empty.

    Args:
        cache_path: Path of the JSONL cache to split.
        out_dir: Directory receiving the shard files; created if needed.
        qid_to_shard: Mapping from question id to shard index.
        num_shards: Number of shard files to create; must be >= 1.

    Raises:
        ValueError: If ``num_shards`` is smaller than 1.
        KeyError: If a cache row has no ``"question_id"`` key.
    """
    _require_positive_shards(num_shards)
    out_dir.mkdir(parents=True, exist_ok=True)

    # ExitStack closes every file that was opened, even when opening a later
    # one fails part-way through.
    with ExitStack() as stack:
        # Open the input first so a missing or unreadable cache does not
        # create or truncate any shard file.
        cache_file = stack.enter_context(open(cache_path, "r", encoding="utf-8"))
        handles = [
            stack.enter_context(
                open(out_dir / f"shard_{shard:02d}.jsonl", "w", encoding="utf-8")
            )
            for shard in range(num_shards)
        ]
        for line in cache_file:
            if not line.strip():
                continue
            row = json.loads(line)
            shard = qid_to_shard.get(row["question_id"])
            if shard is not None:
                handles[shard].write(json.dumps(row, ensure_ascii=False) + "\n")


def main() -> None:
    """Parse command-line arguments and write all shard outputs.

    Writes the dataset shards, the ``ret_cache`` and ``semantic_cache``
    shards, and a ``qid_to_shard.json`` mapping file under ``--out_dir``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--dataset", required=True, help="JSON dataset file to shard")
    parser.add_argument("--ret_cache", required=True, help="retrieval cache (JSONL) to shard")
    parser.add_argument("--semantic_cache", required=True, help="semantic cache (JSONL) to shard")
    parser.add_argument("--out_dir", required=True, help="directory receiving all shard outputs")
    parser.add_argument("--num_shards", type=int, default=8, help="number of shards (default: 8)")
    args = parser.parse_args()

    # Reject bad shard counts as a usage error before touching the filesystem.
    if args.num_shards < 1:
        parser.error("--num_shards must be a positive integer")

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    qid_to_shard = split_json_dataset(Path(args.dataset), out_dir, args.num_shards)
    split_jsonl_cache(Path(args.ret_cache), out_dir / "ret_cache", qid_to_shard, args.num_shards)
    split_jsonl_cache(Path(args.semantic_cache), out_dir / "semantic_cache", qid_to_shard, args.num_shards)

    with open(out_dir / "qid_to_shard.json", "w", encoding="utf-8") as f:
        json.dump(qid_to_shard, f, ensure_ascii=False, indent=2)
        f.write("\n")

    print(f"Wrote {args.num_shards} shards to {out_dir}")


if __name__ == "__main__":
    main()