LeTue09's picture
initial clean commit
1faccd4
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the AIME 2024 / 2025 datasets to parquet format.
"""
import argparse
import json
import os
import datasets
from verl.utils.hdfs_io import copy, makedirs
# Instruction text that build_dataset appends to every question to form the user prompt.
INSTRUCTION_FOLLOWING = "Let's think step by step and output the final answer within \\boxed{}."

# Per-dataset configuration, keyed by the short name accepted on the command line.
# Each entry carries:
#   hf_name       - dataset identifier used when no local path override is supplied
#   split         - split requested when loading the dataset
#   question_key  - source column read as the question text
#   answer_key    - source column read as the ground-truth answer
#   output_name   - file name of the parquet file written to the save directory
#   example_name  - file name of the JSON dump of the first processed row
#   data_source   - value stored in the "data_source" field of every output row
DATASET_SPECS = {
    "aime24": {
        "hf_name": "Maxwell-Jia/AIME_2024",
        "split": "train",
        "question_key": "Problem",
        "answer_key": "Answer",
        "output_name": "aime-2024.parquet",
        "example_name": "aime-2024.example.json",
        "data_source": "aime24",
    },
    "aime25": {
        "hf_name": "yentinglin/aime_2025",
        "split": "train",
        "question_key": "problem",
        # NOTE(review): "solution" is used as the ground-truth column here —
        # confirm it holds the final answer rather than a worked solution.
        "answer_key": "solution",
        "output_name": "aime-2025.parquet",
        "example_name": "aime-2025.example.json",
        "data_source": "aime25",
    },
}
def build_dataset(dataset_name, local_dataset_path=None):
    """Load one configured dataset and convert each row to the output record layout.

    Args:
        dataset_name: Key into ``DATASET_SPECS`` selecting the dataset to build.
        local_dataset_path: Optional path passed to ``datasets.load_dataset``
            in place of the spec's ``hf_name``.

    Returns:
        The result of ``Dataset.map`` over the loaded split, with the original
        columns removed so only the fields produced per row remain.

    Raises:
        KeyError: If ``dataset_name`` is not a key of ``DATASET_SPECS``.
    """
    spec = DATASET_SPECS[dataset_name]

    # A local path, when given, takes precedence over the configured hub name.
    source = spec["hf_name"] if local_dataset_path is None else local_dataset_path
    raw = datasets.load_dataset(source, split=spec["split"])

    question_column = spec["question_key"]
    answer_column = spec["answer_key"]
    source_tag = spec["data_source"]

    def to_record(row, index):
        # Both columns are stringified and stripped before use.
        question = str(row[question_column]).strip()
        answer = str(row[answer_column]).strip()
        user_message = {
            "role": "user",
            "content": f"{question} {INSTRUCTION_FOLLOWING}",
        }
        return {
            "data_source": source_tag,
            "prompt": [user_message],
            "ability": "math",
            "reward_model": {"style": "rule", "ground_truth": answer},
            # Every row is labelled "test" regardless of which split was loaded.
            "extra_info": {"split": "test", "index": index, "question": question},
        }

    return raw.map(to_record, with_indices=True, remove_columns=raw.column_names)
if __name__ == "__main__":
    # Script entry point: parse the CLI, build each selected dataset, write a
    # parquet file plus a one-row JSON example per dataset, and optionally
    # copy the output directory to HDFS.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--local_dir",
        default=None,
        help="Deprecated; use --local_save_dir instead.",
    )
    parser.add_argument(
        "--hdfs_dir",
        default=None,
        help="If set, the local save directory is copied to this HDFS directory.",
    )
    parser.add_argument(
        "--datasets",
        nargs="+",
        default=["aime24", "aime25"],
        help="Datasets to preprocess: any of %s, or 'all'." % ", ".join(DATASET_SPECS),
    )
    parser.add_argument(
        "--aime24_local_dataset_path",
        default=None,
        help="Load aime24 from this path instead of its configured dataset name.",
    )
    parser.add_argument(
        "--aime25_local_dataset_path",
        default=None,
        help="Load aime25 from this path instead of its configured dataset name.",
    )
    parser.add_argument(
        "--local_save_dir",
        default="~/data/aime",
        help="The save directory for the preprocessed dataset.",
    )
    args = parser.parse_args()

    # Expand the "all" shorthand and validate names, keeping request order.
    selected_datasets = []
    for dataset_name in args.datasets:
        if dataset_name == "all":
            # Derived from DATASET_SPECS so the shorthand cannot drift from the table.
            selected_datasets.extend(DATASET_SPECS)
            continue
        if dataset_name not in DATASET_SPECS:
            raise ValueError(f"Unsupported dataset: {dataset_name}")
        selected_datasets.append(dataset_name)
    # dict.fromkeys drops duplicates while preserving first-seen order.
    selected_datasets = list(dict.fromkeys(selected_datasets))

    # The deprecated --local_dir still wins when supplied, with a warning.
    local_save_dir = args.local_dir
    if local_save_dir is not None:
        print("Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead.")
    else:
        local_save_dir = args.local_save_dir

    local_dir = os.path.expanduser(local_save_dir)
    os.makedirs(local_dir, exist_ok=True)

    dataset_path_overrides = {
        "aime24": args.aime24_local_dataset_path,
        "aime25": args.aime25_local_dataset_path,
    }

    for dataset_name in selected_datasets:
        spec = DATASET_SPECS[dataset_name]
        override_path = dataset_path_overrides[dataset_name]

        # Report the source that build_dataset will actually load from.
        if override_path is not None:
            print(f"Loading the {dataset_name} dataset from local path {override_path}...", flush=True)
        else:
            print(f"Loading the {spec['hf_name']} dataset from huggingface...", flush=True)

        dataset = build_dataset(dataset_name, override_path)
        dataset.to_parquet(os.path.join(local_dir, spec["output_name"]))

        # Dump the first processed row next to the parquet file for inspection.
        example_path = os.path.join(local_dir, spec["example_name"])
        with open(example_path, "w", encoding="utf-8") as f:
            json.dump(dataset[0], f, indent=2)

    # Copy once, after every selected dataset has been written locally.
    if args.hdfs_dir is not None:
        makedirs(args.hdfs_dir)
        copy(src=local_dir, dst=args.hdfs_dir)