| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import os |
| | from dataclasses import dataclass |
| | from typing import Any, Dict, List, Literal, Optional, Sequence |
| |
|
| | from transformers.utils import cached_file |
| |
|
| | from ..extras.constants import DATA_CONFIG |
| | from ..extras.misc import use_modelscope, use_openmind |
| |
|
| |
|
| | @dataclass |
| | class DatasetAttr: |
| | r""" |
| | Dataset attributes. |
| | """ |
| |
|
| | |
| | load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"] |
| | dataset_name: str |
| | formatting: Literal["alpaca", "sharegpt"] = "alpaca" |
| | ranking: bool = False |
| | |
| | subset: Optional[str] = None |
| | split: str = "train" |
| | folder: Optional[str] = None |
| | num_samples: Optional[int] = None |
| | |
| | system: Optional[str] = None |
| | tools: Optional[str] = None |
| | images: Optional[str] = None |
| | videos: Optional[str] = None |
| | |
| | chosen: Optional[str] = None |
| | rejected: Optional[str] = None |
| | kto_tag: Optional[str] = None |
| | |
| | prompt: Optional[str] = "instruction" |
| | query: Optional[str] = "input" |
| | response: Optional[str] = "output" |
| | history: Optional[str] = None |
| | |
| | messages: Optional[str] = "conversations" |
| | |
| | role_tag: Optional[str] = "from" |
| | content_tag: Optional[str] = "value" |
| | user_tag: Optional[str] = "human" |
| | assistant_tag: Optional[str] = "gpt" |
| | observation_tag: Optional[str] = "observation" |
| | function_tag: Optional[str] = "function_call" |
| | system_tag: Optional[str] = "system" |
| |
|
| | def __repr__(self) -> str: |
| | return self.dataset_name |
| |
|
| | def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None: |
| | setattr(self, key, obj.get(key, default)) |
| |
|
| |
|
| | def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]: |
| | r""" |
| | Gets the attributes of the datasets. |
| | """ |
| | if dataset_names is None: |
| | dataset_names = [] |
| |
|
| | if dataset_dir == "ONLINE": |
| | dataset_info = None |
| | else: |
| | if dataset_dir.startswith("REMOTE:"): |
| | config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset") |
| | else: |
| | config_path = os.path.join(dataset_dir, DATA_CONFIG) |
| |
|
| | try: |
| | with open(config_path, "r") as f: |
| | dataset_info = json.load(f) |
| | except Exception as err: |
| | if len(dataset_names) != 0: |
| | raise ValueError("Cannot open {} due to {}.".format(config_path, str(err))) |
| |
|
| | dataset_info = None |
| |
|
| | dataset_list: List["DatasetAttr"] = [] |
| | for name in dataset_names: |
| | if dataset_info is None: |
| | if use_modelscope(): |
| | load_from = "ms_hub" |
| | elif use_openmind(): |
| | load_from = "om_hub" |
| | else: |
| | load_from = "hf_hub" |
| | dataset_attr = DatasetAttr(load_from, dataset_name=name) |
| | dataset_list.append(dataset_attr) |
| | continue |
| |
|
| | if name not in dataset_info: |
| | raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG)) |
| |
|
| | has_hf_url = "hf_hub_url" in dataset_info[name] |
| | has_ms_url = "ms_hub_url" in dataset_info[name] |
| | has_om_url = "om_hub_url" in dataset_info[name] |
| |
|
| | if has_hf_url or has_ms_url or has_om_url: |
| | if has_ms_url and (use_modelscope() or not has_hf_url): |
| | dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]) |
| | elif has_om_url and (use_openmind() or not has_hf_url): |
| | dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"]) |
| | else: |
| | dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) |
| | elif "script_url" in dataset_info[name]: |
| | dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) |
| | else: |
| | dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) |
| |
|
| | dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") |
| | dataset_attr.set_attr("ranking", dataset_info[name], default=False) |
| | dataset_attr.set_attr("subset", dataset_info[name]) |
| | dataset_attr.set_attr("split", dataset_info[name], default="train") |
| | dataset_attr.set_attr("folder", dataset_info[name]) |
| | dataset_attr.set_attr("num_samples", dataset_info[name]) |
| |
|
| | if "columns" in dataset_info[name]: |
| | column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"] |
| | if dataset_attr.formatting == "alpaca": |
| | column_names.extend(["prompt", "query", "response", "history"]) |
| | else: |
| | column_names.extend(["messages"]) |
| |
|
| | for column_name in column_names: |
| | dataset_attr.set_attr(column_name, dataset_info[name]["columns"]) |
| |
|
| | if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]: |
| | tag_names = ( |
| | "role_tag", |
| | "content_tag", |
| | "user_tag", |
| | "assistant_tag", |
| | "observation_tag", |
| | "function_tag", |
| | "system_tag", |
| | ) |
| | for tag in tag_names: |
| | dataset_attr.set_attr(tag, dataset_info[name]["tags"]) |
| |
|
| | dataset_list.append(dataset_attr) |
| |
|
| | return dataset_list |
| |
|