#!/usr/bin/env python3
"""Create deterministic dataset/retrieval-cache shards by question id."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Dict
def split_json_dataset(dataset_path: Path, out_dir: Path, num_shards: int) -> Dict[str, int]:
    """Split a JSON-array dataset into round-robin shards by position.

    Reads ``dataset_path`` (a JSON list of entries, each carrying a
    ``question_id`` key), assigns entry ``i`` to shard ``i % num_shards``,
    and writes each shard to ``out_dir/dataset/shard_NN.json``.

    Args:
        dataset_path: Path to the input JSON file (a list of dicts).
        out_dir: Output root; shards land under ``out_dir / "dataset"``.
        num_shards: Number of shards to produce; must be >= 1.

    Returns:
        Mapping from each entry's ``question_id`` to its shard index, used
        downstream to route cache rows to the same shard.

    Raises:
        ValueError: If ``num_shards`` is less than 1.
        KeyError: If an entry lacks a ``question_id`` field.
    """
    if num_shards < 1:
        # Without this, empty data + 0 shards silently returns {} while
        # non-empty data raises ZeroDivisionError — fail loudly instead.
        raise ValueError(f"num_shards must be >= 1, got {num_shards}")
    # Fix: the original `json.load(open(...))` never closed the file handle.
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    qid_to_shard: Dict[str, int] = {}
    shards = [[] for _ in range(num_shards)]
    dataset_dir = out_dir / "dataset"
    dataset_dir.mkdir(parents=True, exist_ok=True)
    for idx, entry in enumerate(data):
        shard = idx % num_shards
        qid_to_shard[entry["question_id"]] = shard
        shards[shard].append(entry)
    for shard, rows in enumerate(shards):
        path = dataset_dir / f"shard_{shard:02d}.json"
        with open(path, "w", encoding="utf-8") as f:
            json.dump(rows, f, ensure_ascii=False, indent=2)
            f.write("\n")  # trailing newline keeps the files POSIX-friendly
    return qid_to_shard
def split_jsonl_cache(cache_path: Path, out_dir: Path, qid_to_shard: Dict[str, int], num_shards: int) -> None:
    """Route rows of a JSONL cache into per-shard JSONL files.

    Each non-blank line of ``cache_path`` is parsed as JSON; rows whose
    ``question_id`` appears in ``qid_to_shard`` are appended to
    ``out_dir/shard_NN.jsonl`` for the mapped shard. Rows with unknown
    question ids are silently dropped.

    Args:
        cache_path: Input JSONL file (one JSON object per line).
        out_dir: Directory for shard files; created if missing.
        qid_to_shard: question_id -> shard index routing table.
        num_shards: Number of shard files to create (even if some end empty).
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    writers = []
    for idx in range(num_shards):
        writers.append(open(out_dir / f"shard_{idx:02d}.jsonl", "w", encoding="utf-8"))
    try:
        with open(cache_path, "r", encoding="utf-8") as src:
            for raw in src:
                stripped = raw.strip()
                if not stripped:
                    continue  # tolerate blank lines in the cache
                record = json.loads(stripped)
                target = qid_to_shard.get(record["question_id"])
                if target is None:
                    continue  # qid not in this dataset split — drop the row
                writers[target].write(json.dumps(record, ensure_ascii=False) + "\n")
    finally:
        # Close every writer even if parsing fails partway through.
        for writer in writers:
            writer.close()
def main() -> None:
    """CLI entry point: shard the dataset and both caches under --out_dir.

    Writes ``dataset/`` shards, ``ret_cache/`` and ``semantic_cache/``
    JSONL shards, and a ``qid_to_shard.json`` routing table.
    """
    ap = argparse.ArgumentParser()
    for flag in ("--dataset", "--ret_cache", "--semantic_cache", "--out_dir"):
        ap.add_argument(flag, required=True)
    ap.add_argument("--num_shards", type=int, default=8)
    opts = ap.parse_args()

    target = Path(opts.out_dir)
    target.mkdir(parents=True, exist_ok=True)

    # The dataset split produces the qid -> shard table that keeps each
    # cache row co-located with its question.
    mapping = split_json_dataset(Path(opts.dataset), target, opts.num_shards)
    split_jsonl_cache(Path(opts.ret_cache), target / "ret_cache", mapping, opts.num_shards)
    split_jsonl_cache(Path(opts.semantic_cache), target / "semantic_cache", mapping, opts.num_shards)

    with open(target / "qid_to_shard.json", "w", encoding="utf-8") as fh:
        json.dump(mapping, fh, ensure_ascii=False, indent=2)
        fh.write("\n")
    print(f"Wrote {opts.num_shards} shards to {target}")


if __name__ == "__main__":
    main()
|