| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import json |
| from enum import Enum, unique |
| from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union |
|
|
| import fsspec |
| from datasets import DatasetDict, concatenate_datasets, interleave_datasets |
|
|
| from ..extras import logging |
|
|
|
|
| if TYPE_CHECKING: |
| from datasets import Dataset, IterableDataset |
|
|
| from ..hparams import DataArguments |
|
|
|
|
# Module-level logger, created through the package's own `logging` helper (imported from `..extras`).
logger = logging.get_logger(__name__)


# Type alias: a list whose elements are each a `str`, a `set[str]` or a `dict[str, str]`.
# NOTE(review): SLOTS is not used anywhere in the visible part of this file; presumably it is
# imported by other modules — confirm before changing it.
SLOTS = list[Union[str, set[str], dict[str, str]]]
|
|
|
|
@unique
class Role(str, Enum):
    r"""String-valued role names.

    The enum also subclasses ``str``, so every member compares equal to its plain string
    value (for example ``Role.USER == "user"``). ``@unique`` makes the class definition
    fail if two members share the same value.
    """

    # NOTE(review): presumably the speaker/message roles of chat-style samples — the code
    # that consumes these values is not visible here; confirm against callers.
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    OBSERVATION = "observation"
|
|
|
|
class DatasetModule(TypedDict):
    r"""Typed dictionary bundling a training dataset and an evaluation dataset.

    ``get_dataset_module`` in this file builds values of this type and only sets the keys
    for the splits it actually finds.
    """

    # Training split (map-style or iterable), or None.
    train_dataset: Optional[Union["Dataset", "IterableDataset"]]
    # Evaluation data: a single dataset, a dict of named datasets, or None.
    eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]
|
|
|
|
def merge_dataset(
    all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
    r"""Combine several datasets into one according to ``data_args.mix_strategy``.

    A list holding exactly one dataset is returned as-is, without looking at the strategy.
    Otherwise:

    - ``"concat"`` concatenates the datasets (a warning is logged in streaming mode);
    - any strategy starting with ``"interleave"`` interleaves them with
      ``data_args.interleave_probs`` and ``seed`` (a warning is logged in non-streaming
      mode). A strategy ending in ``"under"`` passes ``stopping_strategy="first_exhausted"``,
      every other interleave strategy passes ``"all_exhausted"``.

    Raises:
        ValueError: If the strategy is neither ``"concat"`` nor an ``"interleave*"`` one.
    """
    # Nothing to merge: hand the only dataset back untouched.
    if len(all_datasets) == 1:
        return all_datasets[0]

    strategy = data_args.mix_strategy

    if strategy == "concat":
        if data_args.streaming:
            logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")

        return concatenate_datasets(all_datasets)

    # Reject everything that is not an interleave variant before doing any work.
    if not strategy.startswith("interleave"):
        raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")

    if not data_args.streaming:
        logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")

    if strategy.endswith("under"):
        stopping_strategy = "first_exhausted"
    else:
        stopping_strategy = "all_exhausted"

    return interleave_datasets(
        datasets=all_datasets,
        probabilities=data_args.interleave_probs,
        seed=seed,
        stopping_strategy=stopping_strategy,
    )
|
|
|
|
def split_dataset(
    dataset: Optional[Union["Dataset", "IterableDataset"]],
    eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
    data_args: "DataArguments",
    seed: int,
) -> "DatasetDict":
    r"""Split the dataset and return a dataset dict containing the train and validation sets.

    Supports both map-style and iterable (streaming) datasets.

    Args:
        dataset: Training data, or None when only evaluation data is given.
        eval_dataset: Explicit evaluation data. A dict of named datasets is stored under
            ``validation_<name>`` keys, a single dataset under ``validation``.
        data_args: Provides ``val_size``, ``streaming`` and ``buffer_size``.
        seed: Seed used for shuffling and for the train/validation split.

    Returns:
        A ``DatasetDict`` with whichever of ``train``, ``validation`` and
        ``validation_<name>`` entries apply.

    Raises:
        ValueError: If both ``eval_dataset`` and a non-zero ``val_size`` are given.
    """
    if eval_dataset is not None and data_args.val_size > 1e-6:
        raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")

    dataset_dict = {}
    if dataset is not None:
        if data_args.streaming:
            dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

        if data_args.val_size > 1e-6:
            if data_args.streaming:
                # In streaming mode val_size is truncated to an integer sample count, so a
                # fractional value below 1 becomes take(0) / skip(0).
                num_val = int(data_args.val_size)
                dataset_dict["validation"] = dataset.take(num_val)
                dataset_dict["train"] = dataset.skip(num_val)
            else:
                # Values above 1 are an absolute sample count, otherwise a fraction.
                val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
                # Split exactly once; the previous code ran the same split twice and threw
                # the first result away.
                splits = dataset.train_test_split(test_size=val_size, seed=seed)
                dataset_dict = {"train": splits["train"], "validation": splits["test"]}
        else:
            dataset_dict["train"] = dataset

    if eval_dataset is not None:
        if isinstance(eval_dataset, dict):
            dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
        else:
            if data_args.streaming:
                eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

            dataset_dict["validation"] = eval_dataset

    return DatasetDict(dataset_dict)
|
|
|
|
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
    r"""Turn a dataset or a ``DatasetDict`` into a dataset module.

    A plain dataset becomes the ``train_dataset``. For a ``DatasetDict`` the ``train``
    split (if present) becomes ``train_dataset``; the ``validation`` split, when present,
    becomes ``eval_dataset``. Otherwise every ``validation_<name>`` split is collected
    into a ``{name: dataset}`` dict, which is stored as ``eval_dataset`` only when it is
    non-empty.
    """
    # Anything that is not a DatasetDict is treated as training data only.
    if not isinstance(dataset, DatasetDict):
        return {"train_dataset": dataset}

    module: DatasetModule = {}
    if "train" in dataset:
        module["train_dataset"] = dataset["train"]

    # A single "validation" split takes precedence over named validation splits.
    if "validation" in dataset:
        module["eval_dataset"] = dataset["validation"]
        return module

    prefix = "validation_"
    named_evals = {name[len(prefix) :]: dataset[name] for name in dataset.keys() if name.startswith(prefix)}
    if named_evals:
        module["eval_dataset"] = named_evals

    return module
|
|
|
|
def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem":
    r"""Set up a filesystem object based on the path protocol.

    Args:
        path: Cloud path starting with ``s3://``, ``gs://`` or ``gcs://``.
        anon: Whether to request anonymous (unauthenticated) access.

    Returns:
        The fsspec filesystem matching the protocol of ``path``.

    Raises:
        ValueError: If the protocol is not supported or the path does not exist.
    """
    if path.startswith("s3://"):
        # The S3 backend spells anonymous access as `anon=True`.
        storage_options = {"anon": anon} if anon else {}
        fs = fsspec.filesystem("s3", **storage_options)
    elif path.startswith(("gs://", "gcs://")):
        # The GCS backend has no `anon` option; anonymous access is requested with
        # `token="anon"`. Forwarding `anon=True` left the flag without effect here.
        storage_options = {"token": "anon"} if anon else {}
        fs = fsspec.filesystem("gcs", **storage_options)
    else:
        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.")

    if not fs.exists(path):
        raise ValueError(f"Path does not exist: {path}.")

    return fs
|
|
|
|
def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]:
    r"""Helper function to read JSON/JSONL files using fsspec.

    Args:
        fs: Filesystem object used to open ``path``.
        path: File to read. A ``.jsonl`` suffix selects line-by-line parsing, skipping
            blank lines; any other path is parsed as one JSON document.

    Returns:
        The parsed records for ``.jsonl`` files, otherwise whatever the JSON document
        decodes to.
    """
    # Decode as UTF-8 explicitly instead of relying on the platform's locale encoding.
    with fs.open(path, "r", encoding="utf-8") as f:
        if path.endswith(".jsonl"):
            return [json.loads(line) for line in f if line.strip()]

        return json.load(f)
|
|
|
|
def read_cloud_json(cloud_path: str) -> list[Any]:
    r"""Read a JSON/JSONL file from cloud storage (S3 or GCS).

    Args:
        cloud_path: str
            Cloud path in the format:
            - 's3://bucket-name/file.json' for AWS S3
            - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
            If the path is a directory, every ``.json``/``.jsonl`` entry listed directly
            under it is read and the records are concatenated in listing order.

    Returns:
        The records of all files that were read, as one flat list.

    Raises:
        ValueError: If the path yields no JSON/JSONL file (or ``setup_fs`` rejects it).
        TypeError: If a file does not decode to a list of records.
    """
    try:
        fs = setup_fs(cloud_path, anon=True)
    except Exception:
        # Deliberate best effort: anonymous access failed, retry with default credentials.
        fs = setup_fs(cloud_path)

    # `detail=False` returns plain path strings on every fsspec backend; the "Key" field of
    # the detailed listing only exists for S3 and raised KeyError on GCS.
    candidates = fs.ls(cloud_path, detail=False) if fs.isdir(cloud_path) else [cloud_path]

    # Build a real list: a lazy `filter` object is always truthy, which made the emptiness
    # check below unreachable.
    files = [file for file in candidates if file.endswith((".json", ".jsonl"))]
    if not files:
        raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")

    records: list[Any] = []
    for file in files:
        content = _read_json_with_fs(fs, file)
        # Keep failing with TypeError on non-list documents, as list concatenation did.
        if not isinstance(content, list):
            raise TypeError(f"Expected a list of records in {file}, got {type(content).__name__}.")

        records.extend(content)

    return records
|
|