File size: 2,744 Bytes
9c60174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python3
"""Create deterministic dataset/retrieval-cache shards by question id."""

from __future__ import annotations

import argparse
import json
from contextlib import ExitStack
from pathlib import Path
from typing import Dict


def split_json_dataset(dataset_path: Path, out_dir: Path, num_shards: int) -> Dict[str, int]:
    """Round-robin split a JSON-array dataset into ``num_shards`` shard files.

    Args:
        dataset_path: JSON file holding a list of entries, each with a
            ``question_id`` key.
        out_dir: Output root; shard files are written to ``out_dir/dataset``.
        num_shards: Number of shard files to create (some may be empty).

    Returns:
        Mapping from ``question_id`` to the shard index it was assigned,
        so cache splitting can route rows to the same shard.
    """
    # Fix: the original used json.load(open(...)), leaking the file handle.
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    qid_to_shard: Dict[str, int] = {}
    shards = [[] for _ in range(num_shards)]
    dataset_dir = out_dir / "dataset"
    dataset_dir.mkdir(parents=True, exist_ok=True)

    # Deterministic round-robin by original position keeps shards balanced
    # and reproducible across runs.
    for idx, entry in enumerate(data):
        shard = idx % num_shards
        qid_to_shard[entry["question_id"]] = shard
        shards[shard].append(entry)

    for shard, rows in enumerate(shards):
        path = dataset_dir / f"shard_{shard:02d}.json"
        with open(path, "w", encoding="utf-8") as f:
            json.dump(rows, f, ensure_ascii=False, indent=2)
            f.write("\n")  # trailing newline for POSIX-friendly text files
    return qid_to_shard


def split_jsonl_cache(cache_path: Path, out_dir: Path, qid_to_shard: Dict[str, int], num_shards: int) -> None:
    """Route each JSONL cache row to the shard file of its question id.

    Args:
        cache_path: Source JSONL file; each non-blank line is a JSON object
            with a ``question_id`` key.
        out_dir: Directory for the per-shard ``shard_NN.jsonl`` files
            (created if missing).
        qid_to_shard: Assignment produced by ``split_json_dataset``.
        num_shards: Number of shard files to create (some may be empty).

    Rows whose ``question_id`` is not present in ``qid_to_shard`` are
    silently dropped (the cache may contain stale entries).
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    # Fix: the original opened all handles in a list comprehension *before*
    # the try/finally, so a failure on the k-th open leaked the first k-1
    # handles. ExitStack registers each handle the moment it opens.
    with ExitStack() as stack:
        handles = [
            stack.enter_context(
                open(out_dir / f"shard_{shard:02d}.jsonl", "w", encoding="utf-8")
            )
            for shard in range(num_shards)
        ]
        with open(cache_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines in hand-edited caches
                row = json.loads(line)
                shard = qid_to_shard.get(row["question_id"])
                if shard is not None:
                    handles[shard].write(json.dumps(row, ensure_ascii=False) + "\n")


def main() -> None:
    """Parse CLI flags, shard the dataset and both caches, write the qid map."""
    parser = argparse.ArgumentParser()
    for flag in ("--dataset", "--ret_cache", "--semantic_cache", "--out_dir"):
        parser.add_argument(flag, required=True)
    parser.add_argument("--num_shards", type=int, default=8)
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Dataset split first: it produces the qid -> shard assignment that
    # both cache splits must follow so each shard is self-contained.
    qid_to_shard = split_json_dataset(Path(args.dataset), out_dir, args.num_shards)
    for source, subdir in ((args.ret_cache, "ret_cache"), (args.semantic_cache, "semantic_cache")):
        split_jsonl_cache(Path(source), out_dir / subdir, qid_to_shard, args.num_shards)

    # Persist the assignment so downstream tooling can locate a question's shard.
    with open(out_dir / "qid_to_shard.json", "w", encoding="utf-8") as f:
        json.dump(qid_to_shard, f, ensure_ascii=False, indent=2)
        f.write("\n")

    print(f"Wrote {args.num_shards} shards to {out_dir}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()