import argparse
import hashlib
import json
import os
import subprocess
from pathlib import Path
from typing import Dict, Tuple

from tqdm import tqdm

from datasets import concatenate_datasets, config, load_dataset

"""
This script converts a supported chat dataset (ultrachat, sharegpt, and related sets) into JSONL files with the following schema:
{
    "id": str,
    "conversations": [
        {
            "role": str,
            "content": str
        }
    ],
}
"""

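# Roles not listed in this mapping (e.g. "system") are skipped by the
# sharegpt-style processors below and counted as skipped messages.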
ROLE_MAPPING = {
    "human": "user",
    "gpt": "assistant",
    "chatgpt": "assistant",
    "bing": "assistant",
    "bard": "assistant",
}


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        choices=[
            "ultrachat",
            "sharegpt",
            "eaglechat",
            "perfectblend",
            "perfectblend-llama3.1-8b-instruct",
            "perfectblend-llama3.3-70b-instruct",
            "perfectblend-llama4-scout-instruct",
            "perfectblend-llama4-maverick-instruct",
            "magpie-qwen2.5-pro-1m-v0.1",
            "sharegpt4v",
            "allava4v",
            "opc",
        ],
        help="The demo dataset to quickly run the training for speculative decoding",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default=None,
        help="The path to save the processed dataset, if not specified, the dataset will be saved in the cache/dataset/dataset_name directory of the root path",
    )
    parser.add_argument(
        "--data-path",
        type=str,
        default=None,
        help="The path to the custom dataset, if not specified, the default dataset will be loaded",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=None,
        help="The number of samples to process from the dataset, if not specified, all samples will be processed",
    )
    parser.add_argument(
        "--split-eval",
        action="store_true",
        help="Whether to split the dataset into train and eval sets, default is False",
    )
    parser.add_argument(
        "--opc-subset",
        type=str,
        default="largescale_diverse_instruct",
        choices=[
            "largescale_diverse_instruct",
            "filtered_infinity_instruct",
            "realuser_instruct",
            "all",
        ],
        help="The subset of OpenCoder opc-sft-stage1 dataset to use, or 'all' to use all subsets (default: largescale_diverse_instruct)",
    )
    return parser.parse_args()


def get_cache_dir(dataset_name):
    cache_dir = None
    if dataset_name == "sharegpt4v":
        raise ValueError("Downloading 'sharegpt4v' is not supported.")
    elif dataset_name == "allava4v":
        cache_dir = os.path.join(
            config.HF_DATASETS_CACHE, "FreedomIntelligence", "ALLaVA"
        )
    else:
        raise ValueError(
            f"Dataset '{dataset_name}' is not a supported VLM dataset for download."
        )
    return cache_dir


def download_vlm_dataset(dataset_name: str) -> None:
    """Download VLM's dataset such as sharegpt4v and allava4v"""
    if dataset_name == "sharegpt4v":
        raise Exception("Don't Support Download sharegpt4v.")
    elif dataset_name == "allava4v":
        cache_dir = get_cache_dir(dataset_name)
        os.makedirs(cache_dir, exist_ok=True)
        script_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            "datasets",
            "download_laion.sh",
        )
        os.chmod(script_path, 0o755)
        if not os.path.exists(
            os.path.join(cache_dir, "allava_laion", "image_chunks", "images_0.zip")
        ):
            result = subprocess.run(
                ["bash", script_path],
                cwd=cache_dir,
                capture_output=True,
                text=True,
            )
            if result.returncode != 0:
                raise RuntimeError(f"Download image dataset failed: {result.stderr}")
            print("##### allava4v dataset Download Complete #####")
        else:
            print("##### allava4v dataset has existed.")
    else:
        raise Exception(f"Don't support {dataset_name}")


def process_ultrachat_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
    """Process a row from the ultrachat dataset.

    The function expects a row with the following schema:
    "messages": [
        {
            "role": "user" | "assistant",
            "content": str
        }
    ]
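
    Illustrative example (made-up values, not from the real dataset):
        input row:
            {"prompt_id": "abc123",
             "messages": [{"role": "user", "content": "Hi"},
                          {"role": "assistant", "content": "Hello!"}]}
        returns:
            ({"id": "abc123",
              "conversations": [{"role": "user", "content": "Hi"},
                                {"role": "assistant", "content": "Hello!"}]}, 0)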
    """
    conversations = row["messages"]
    formatted_conversations = []
    for message in conversations:
        role = message["role"]
        content = message["content"]
        assert role in ["user", "assistant"]
        formatted_conversations.append({"role": role, "content": content})
    row = {"id": row["prompt_id"], "conversations": formatted_conversations}
    return row, 0


def process_sharegpt_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
    """
    sharegpt dataset schema:
    {
        "conversations": [
            {
                "from": <system|human|gpt>,
                "value": <message>,
            },
            ...
        ]
    }
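
    Illustrative example (made-up values): a "system" turn is not in ROLE_MAPPING,
    so it is skipped and counted:
        input row:
            {"id": "conv-1",
             "conversations": [{"from": "system", "value": "Be helpful."},
                               {"from": "human", "value": "Hi"},
                               {"from": "gpt", "value": "Hello!"}]}
        returns:
            ({"id": "conv-1",
              "conversations": [{"role": "user", "content": "Hi"},
                                {"role": "assistant", "content": "Hello!"}]}, 1)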
    """
    conversations = row["conversations"]
    formatted_conversations = []
    skipped_count = 0
    for message in conversations:
        if message["from"] not in ROLE_MAPPING:
            skipped_count += 1
            continue
        new_role = ROLE_MAPPING[message["from"]]
        content = message["value"]
        formatted_conversations.append({"role": new_role, "content": content})

    row = {"id": row["id"], "conversations": formatted_conversations}
    return row, skipped_count


def process_sharegpt4v_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
    """
    sharegpt4v dataset schema:
    {
        "id": str,
        "image": str,  # path to the image
        "conversations": [
            {
                "from": <human|gpt>,
                "value": <message>,
            },
            ...
        ]
    }
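
    Illustrative example (made-up values): the leading "<image>" tag (and its
    trailing newline) is stripped from user turns, and the image path is resolved
    relative to the dataset cache directory:
        input row:
            {"id": "img-1", "image": "allava_laion/images/0.jpg",
             "conversations": [{"from": "human", "value": "<image> Describe this."},
                               {"from": "gpt", "value": "A photo of a cat."}]}
        returns (if the resolved image path exists):
            ({"id": "img-1", "image": "<cache_dir>/allava_laion/images/0.jpg",
              "conversations": [{"role": "user", "content": "Describe this."},
                                {"role": "assistant", "content": "A photo of a cat."}]}, 0)
        returns (None, None) if the image file is missing.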
    """
    cache_dir = get_cache_dir(dataset_name)
    conversations = row["conversations"]
    image = os.path.join(cache_dir, row["image"])
    if not os.path.exists(image):
        print(f"Image path {image} does not exist, skipping this sample.")
        return None, None
    formatted_conversations = []
    skipped_count = 0
    for message in conversations:
        if message["from"] not in ROLE_MAPPING:
            skipped_count += 1
            continue
        new_role = ROLE_MAPPING[message["from"]]
        if new_role == "user":
            text_content = message["value"].replace("<image>\n", "")
            content = text_content
        else:
            content = message["value"]
        formatted_conversations.append({"role": new_role, "content": content})

    row = {"id": row["id"], "image": image, "conversations": formatted_conversations}
    return row, skipped_count


def load_dataset_from_path(data_path: Path):
    # Map the file extension to a datasets builder name; ".jsonl" files use the "json" builder.
    suffix = data_path.suffix.lstrip(".")
    builder = "json" if suffix == "jsonl" else suffix
    ds = load_dataset(builder, data_files=str(data_path), split="train")
    return ds


def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name):
    """Apply proc_fn to every row and write {dataset_name}_train.jsonl (and _test.jsonl) under output_path."""
    train_output_jsonl_path = output_path.joinpath(f"{dataset_name}_train.jsonl")
    if train_output_jsonl_path.exists():
        print(
            f"The dataset {dataset_name} has already been processed and saved in {train_output_jsonl_path}, skipping..."
        )
        return

    total_skipped_count = 0
    with open(train_output_jsonl_path, "w") as f:
        for item in tqdm(train_ds, desc=f"Processing {dataset_name} dataset"):
            if proc_fn is not None:
                row, skipped_count = proc_fn(item, dataset_name)
                if row is None:
                    continue
                total_skipped_count += skipped_count
            else:
                row = item
            f.write(json.dumps(row, ensure_ascii=False) + "\n")

    if test_ds is not None:
        test_output_jsonl_path = output_path.joinpath(f"{dataset_name}_test.jsonl")
        with open(test_output_jsonl_path, "w") as f:
            for item in tqdm(test_ds, desc=f"Processing {dataset_name} test dataset"):
                if proc_fn is not None:
                    row, skipped_count = proc_fn(item, dataset_name)
                    if row is None:
                        continue
                    total_skipped_count += skipped_count
                else:
                    row = item
                f.write(json.dumps(row, ensure_ascii=False) + "\n")

    if total_skipped_count > 0:
        total_samples = len(train_ds) + (len(test_ds) if test_ds is not None else 0)
        print(
            f"Skipped {total_skipped_count} messages across {total_samples} samples for {dataset_name}"
        )


def process_opc_sft_stage1(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
    """Process a row from the OpenCoder opc-sft-stage1 dataset.

    Each row has "instruction" and "output" fields; a stable id is derived by hashing them.
    """
    row_id = hashlib.md5((row["instruction"] + row["output"]).encode()).hexdigest()
    processed_row = {
        "id": row_id,
        "conversations": [
            {"role": "user", "content": row["instruction"]},
            {"role": "assistant", "content": row["output"]},
        ],
    }
    return processed_row, 0


def add_index(row, idx) -> Dict:
    """Assign the row's dataset index as its sample id."""
    row["id"] = idx
    return row


def main():
    args = parse_args()
    # load dataset
    if args.dataset == "ultrachat":
        ds = load_dataset("HuggingFaceH4/ultrachat_200k")["train_sft"]
        proc_fn = process_ultrachat_row
    elif args.dataset == "sharegpt":
        if args.data_path is None:
            ds = load_dataset("Aeala/ShareGPT_Vicuna_unfiltered")["train"]
        else:
            print("Loading dataset from custom data path: ", args.data_path)
            ds = load_dataset_from_path(Path(args.data_path))
        proc_fn = process_sharegpt_row
    elif args.dataset == "eaglechat":
        ds = load_dataset("zhaode/EagleChat")["train"]
        proc_fn = lambda row, dataset_name=None: (row, 0)
    elif args.dataset == "perfectblend":
        ds = load_dataset("mlabonne/open-perfectblend")["train"]
        ds = ds.map(add_index, with_indices=True)
        proc_fn = process_sharegpt_row
    elif args.dataset == "perfectblend-llama3.1-8b-instruct":
        ds = load_dataset("frankleeeee/PerfectBlend-Regenerated-Llama-3.1-8B-Instruct")[
            "train"
        ]
        ds = ds.map(add_index, with_indices=True)
        proc_fn = None
    elif args.dataset == "perfectblend-llama3.3-70b-instruct":
        ds = load_dataset(
            "frankleeeee/PerfectBlend-Regenerated-Llama-3.3-70B-Instruct"
        )["train"]
        ds = ds.map(add_index, with_indices=True)
        proc_fn = None
    elif args.dataset == "perfectblend-llama4-scout-instruct":
        ds = load_dataset(
            "frankleeeee/PerfectBlend-Regenerated-Llama-4-Scout-17B-16E-Instruct"
        )["train"]
        ds = ds.map(add_index, with_indices=True)
        proc_fn = None
    elif args.dataset == "perfectblend-llama4-maverick-instruct":
        ds = load_dataset(
            "frankleeeee/PerfectBlend-Regenerated-Llama-4-Maverick-17B-128E-Instruct"
        )["train"]
        ds = ds.map(add_index, with_indices=True)
        proc_fn = None
    elif args.dataset == "magpie-qwen2.5-pro-1m-v0.1":
        ds = load_dataset("Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1")["train"]
        ds = ds.rename_column("uuid", "id")
        proc_fn = process_sharegpt_row
    elif args.dataset == "sharegpt4v":
        ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")["train"]
        raise Exception("Not supported sharegpt4v now")
        download_vlm_dataset(args.dataset)
        proc_fn = process_sharegpt4v_row
    elif args.dataset == "allava4v":
        ds = load_dataset("FreedomIntelligence/ALLaVA-4V", name="allava_laion")[
            "instruct"
        ]
        download_vlm_dataset(args.dataset)
        proc_fn = process_sharegpt4v_row
    elif args.dataset == "opc":
        if args.opc_subset == "all":
            # Load all subsets and concatenate them
            subsets = [
                "largescale_diverse_instruct",
                "filtered_infinity_instruct",
                "realuser_instruct",
            ]
            datasets_list = [
                load_dataset("OpenCoder-LLM/opc-sft-stage1", subset)["train"]
                for subset in subsets
            ]
            ds = concatenate_datasets(datasets_list)
        else:
            ds = load_dataset("OpenCoder-LLM/opc-sft-stage1", args.opc_subset)["train"]
        proc_fn = process_opc_sft_stage1
    else:
        raise ValueError(
            f"Unsupported dataset '{args.dataset}'. This script only supports the datasets listed in the --dataset choices for demo purposes; to use another dataset, please modify this script."
        )
    # filter and split dataset
    if args.sample_size is not None and args.sample_size < len(ds):
        ds = ds.select(range(args.sample_size))
        print(f"Processing {args.sample_size} samples from the dataset {args.dataset}")
    if args.split_eval:
        ds = ds.train_test_split(test_size=0.05)
        train_ds = ds["train"]
        test_ds = ds["test"]
    else:
        train_ds = ds
        test_ds = None

    if args.output_path is None:
        root_path = Path(__file__).parent.parent
        output_path = root_path.joinpath("cache", "dataset")
        output_path.mkdir(parents=True, exist_ok=True)
    else:
        output_path = Path(args.output_path)
        output_path.mkdir(parents=True, exist_ok=True)

    process_and_save_ds(train_ds, test_ds, output_path, proc_fn, args.dataset)


if __name__ == "__main__":
    main()