| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Fine-tuning the library models for causal language modeling using |
| Fill-in-the-Middle (FIM) objective on a text file or a dataset. |
| |
| Here is the full list of checkpoints on the hub that can be fine-tuned by this script: |
| https://huggingface.co/models?filter=text-generation |
| """ |
| |
|
|
| import logging |
| import math |
| import os |
| import sys |
| from dataclasses import dataclass, field |
| from itertools import chain |
| from typing import Optional |
|
|
| import datasets |
| import evaluate |
| import numpy as np |
| import torch |
| from datasets import load_dataset |
|
|
| import transformers |
| from transformers import ( |
| CONFIG_MAPPING, |
| MODEL_FOR_CAUSAL_LM_MAPPING, |
| AutoConfig, |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| HfArgumentParser, |
| Trainer, |
| TrainingArguments, |
| default_data_collator, |
| is_torch_xla_available, |
| set_seed, |
| ) |
| from transformers.integrations import is_deepspeed_zero3_enabled |
| from transformers.testing_utils import CaptureLogger |
| from transformers.trainer_utils import get_last_checkpoint |
| from transformers.utils import check_min_version, send_example_telemetry |
| from transformers.utils.versions import require_version |
|
|
|
|
| |
# Version guards for the installed `transformers` and `datasets` packages
# (presumably these fail fast on an unsupported install -- confirm against the helpers).
check_min_version("4.52.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

# Module-level logger; its level is configured inside `main()`.
logger = logging.getLogger(__name__)

# Config classes registered in the causal-LM mapping, and the `model_type` string of each one.
# `MODEL_TYPES` feeds the `--model_type` help text of `ModelArguments`.
MODEL_CONFIG_CLASSES = [*MODEL_FOR_CAUSAL_LM_MAPPING.keys()]
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)
|
|
|
|
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    Each field's ``metadata["help"]`` string is its command-line help text.

    Raises:
        ValueError: from ``__post_init__`` when ``config_overrides`` is combined with ``config_name`` or
            ``model_name_or_path``.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # The default is None, so the annotation must be Optional.
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "set True will benefit LLM loading time and RAM consumption."
            )
        },
    )
    pad_to_multiple_of: bool = field(
        default=False,
        metadata={
            # Implicit string concatenation: commas between the two parts would turn the help into a
            # tuple, which argparse cannot format as help text.
            "help": (
                "Whether to pad the embedding layer to a multiple depending on the device. "
                "For NVIDIA GPUs, this will be a multiple of 8, for TPUs a multiple of 128."
            )
        },
    )
    attn_implementation: Optional[str] = field(
        default="sdpa", metadata={"help": ("The attention implementation to use. ")}
    )

    def __post_init__(self):
        """Reject `--config_overrides` when a config name or a model checkpoint is also given."""
        config_is_given = self.config_name is not None or self.model_name_or_path is not None
        if self.config_overrides is not None and config_is_given:
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )
|
|
|
|
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Each field's ``metadata["help"]`` string is its command-line help text.

    Raises:
        ValueError: from ``__post_init__`` when no data source is given, when a data file does not end in
            ``csv``/``json``/``txt``, or when a FIM rate is not a probability in ``[0.0, 1.0]``.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    fim_rate: Optional[float] = field(
        default=0.5,
        metadata={
            "help": (
                "Optional probability with which the FIM transformation is applied to the example. "
                "Default is 0.5. A rate of 1.0 means every example will undergo FIM transformation, "
                "while a rate of 0.0 means no example will."
            )
        },
    )
    fim_spm_rate: Optional[float] = field(
        default=0.5,
        metadata={
            "help": (
                "Within the examples undergoing FIM transformation, this rate determines the probability "
                "of applying the Suffix-Prefix-Middle (SPM) mode. "
                "Default is 0.5. A rate of 1.0 means all FIM transformations will use SPM, "
                "while a rate of 0.0 means none will."
            )
        },
    )
    truncate_or_pad: Optional[bool] = field(
        default=True,
        metadata={
            "help": (
                "Indicates whether the transformed example should be truncated or padded to maintain "
                "the same length as the original example. "
                "Default is True. If False, the function will not truncate or pad the examples."
            )
        },
    )
    fim_prefix_token: Optional[str] = field(
        default="<fim_prefix>",
        metadata={"help": ("Fill-in-Middle Prefix token. Defaults to '<fim_prefix>'.")},
    )
    fim_middle_token: Optional[str] = field(
        default="<fim_middle>",
        metadata={"help": ("Fill-in-Middle Middle token. Defaults to '<fim_middle>'.")},
    )
    fim_suffix_token: Optional[str] = field(
        default="<fim_suffix>",
        metadata={"help": ("Fill-in-Middle Suffix token. Defaults to '<fim_suffix>'.")},
    )
    pad_token: Optional[str] = field(
        default="<fim_pad>",
        metadata={
            "help": (
                "Fill-in-Middle Pad token. Used only when 'truncate_or_pad' is set to True. Defaults to '<fim_pad>'."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    keep_linebreaks: bool = field(
        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
    )

    def __post_init__(self):
        """Validate the data source, the data-file extensions and the FIM probabilities."""
        if self.streaming:
            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")

        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")

        # Explicit exceptions rather than `assert`, so the checks survive `python -O`.
        for name in ("train_file", "validation_file"):
            path = getattr(self, name)
            if path is not None and path.split(".")[-1] not in ("csv", "json", "txt"):
                raise ValueError(f"`{name}` should be a csv, a json or a txt file.")

        # Both rates are used as Bernoulli probabilities; reject bad values here instead of
        # failing later in the middle of dataset preprocessing.
        for name in ("fim_rate", "fim_spm_rate"):
            rate = getattr(self, name)
            if rate is None or not 0.0 <= rate <= 1.0:
                raise ValueError(f"`{name}` must be a probability between 0.0 and 1.0, got {rate!r}.")
|
|
|
|
def main():
    """
    Fine-tune (or train from scratch) a causal language model with the Fill-in-the-Middle objective.

    Steps, in order: parse the three argument dataclasses (from the command line or from a single
    `.json` file), configure logging, detect a resumable checkpoint, load the raw datasets, load the
    config/tokenizer/model, register the FIM sentinel tokens and initialise their embeddings, tokenize,
    apply the FIM transformation, group the tokens into blocks of `block_size`, then train and/or
    evaluate with `Trainer` and finally push to the hub or write a model card.

    Raises:
        ValueError: when the output directory is non-empty and holds no checkpoint, when no tokenizer
            can be loaded, or when `--do_train` / `--do_eval` is requested without the matching split.
    """
    from contextlib import nullcontext

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single `.json` argument is treated as a file holding all the arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    send_example_telemetry("run_fim", model_args, data_args)

    # Set up logging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Short per-process summary.
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Look for a checkpoint to resume from in the output directory.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Seed everything before the model is created.
    set_seed(training_args.seed)

    # Dedicated NumPy random state that drives the FIM transformation below.
    np_rng = np.random.RandomState(seed=training_args.seed)

    # Load the raw datasets, either from a named dataset or from local csv/json/txt files. When there
    # is no "validation" split, the first `validation_split_percentage` percent of "train" is used as
    # validation and the remainder as train.
    if data_args.dataset_name is not None:
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            streaming=data_args.streaming,
            trust_remote_code=model_args.trust_remote_code,
        )
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
                trust_remote_code=model_args.trust_remote_code,
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
                trust_remote_code=model_args.trust_remote_code,
            )
    else:
        data_files = {}
        dataset_args = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = (
            data_args.train_file.split(".")[-1]
            if data_args.train_file is not None
            else data_args.validation_file.split(".")[-1]
        )
        if extension == "txt":
            # Plain-text files go through the "text" loader.
            extension = "text"
            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            **dataset_args,
        )
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )

    # Load the config: an explicit config name wins, then the checkpoint, else a fresh config.
    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
            logger.info(f"New config: {config}")

    # Load the tokenizer; this script cannot create one from scratch.
    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    # Load pretrained weights, or build an untrained model from the config.
    if model_args.model_name_or_path:
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
            attn_implementation=model_args.attn_implementation,
        )
    else:
        model = AutoModelForCausalLM.from_config(
            config,
            trust_remote_code=model_args.trust_remote_code,
            attn_implementation=model_args.attn_implementation,
        )
        # Keyed by data pointer so that parameters sharing storage are counted once.
        n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
        logger.info(f"Training new model from scratch - Total size={n_params / 2**20:.2f}M params")

    # FIM sentinel tokens, plus the pad token when truncation/padding is enabled.
    special_tokens = [data_args.fim_prefix_token, data_args.fim_middle_token, data_args.fim_suffix_token]
    if data_args.truncate_or_pad:
        special_tokens.append(data_args.pad_token)

    # Device-dependent padding of the embedding matrix, applied only when `--pad_to_multiple_of` is set
    # (the flag was previously declared but never read).
    pad_factor = 1
    if model_args.pad_to_multiple_of:
        if torch.cuda.is_available():
            pad_factor = 8
        elif is_torch_xla_available(check_is_tpu=True):
            pad_factor = 128

    # Work out which tokens are genuinely new *before* registering them, so that only their embedding
    # rows are re-initialised below; tokens the tokenizer already knows keep their current weights.
    known_tokens = tokenizer.get_vocab()
    new_tokens = [token for token in dict.fromkeys(special_tokens) if token not in known_tokens]
    tokenizer.add_tokens(special_tokens)
    new_token_ids = tokenizer.convert_tokens_to_ids(new_tokens)
    original_embeddings = model.get_input_embeddings()

    def gathered(parameter):
        """Context that gathers `parameter` under DeepSpeed ZeRO-3 and is a no-op otherwise."""
        if is_deepspeed_zero3_enabled():
            import deepspeed

            return deepspeed.zero.GatheredParameters(parameter, modifier_rank=0)
        return nullcontext()

    new_token_embeddings = None
    with gathered(original_embeddings.weight):
        if new_token_ids:
            # Sample the new rows from a tight multivariate normal around the mean of the existing
            # embeddings, so they start close to the original embeddings rather than at arbitrary values.
            # The statistics come from the weight *tensor* (the embedding module itself has no
            # `.mean`/`.size`) and are taken before resizing, since resizing may replace the weights
            # of this very module. float32 is used because the Cholesky factorisation inside
            # `MultivariateNormal` is not available for every half-precision dtype/device.
            weight = original_embeddings.weight.detach().to(torch.float32)
            mean = weight.mean(dim=0)
            n = weight.size()[0]
            sigma = ((weight - mean).T @ (weight - mean)) / n
            dist = torch.distributions.multivariate_normal.MultivariateNormal(
                mean,
                covariance_matrix=1e-5 * sigma,
            )
            new_token_embeddings = torch.stack(
                tuple(dist.sample() for _ in new_token_ids),
                dim=0,
            )

        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=pad_factor)
        embeddings = model.get_input_embeddings()

    if new_token_embeddings is not None:
        with gathered(embeddings.weight):
            # Address the rows by token id: with padding, the trailing rows of the matrix are padding
            # rows, not the rows of the newly added tokens.
            embeddings.weight.data[new_token_ids] = new_token_embeddings.to(
                device=embeddings.weight.device, dtype=embeddings.weight.dtype
            )

    # Install the updated embedding layer on the model.
    model.set_input_embeddings(embeddings)

    logger.info("Added special tokens to the tokenizer and resized model's embedding layer")

    # Tokenize every text. The text column is "text" when present, else the first column.
    if training_args.do_train:
        column_names = list(raw_datasets["train"].features)
    else:
        column_names = list(raw_datasets["validation"].features)
    text_column_name = "text" if "text" in column_names else column_names[0]

    # Fetched up front so the logger object is not created inside the mapped function.
    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        """Tokenize a batch of texts; add a note after the tokenizer's over-length warning."""
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # Inputs can be much longer than `block_size`; they are chunked later by `group_texts`.
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
                " before being passed to the model."
            )
        return output

    with training_args.main_process_first(desc="dataset map tokenization"):
        if not data_args.streaming:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
        else:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                remove_columns=column_names,
            )

    # Resolve the block size: the tokenizer maximum (capped when it exceeds the model's position
    # embeddings) or the user's value (capped at the tokenizer maximum).
    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > config.max_position_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                f"Using block_size={min(1024, config.max_position_embeddings)} instead. You can change that default value by passing --block_size xxx."
            )
            block_size = min(1024, config.max_position_embeddings)
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    def group_texts(examples):
        """Concatenate all sequences of a batch and split them into chunks of `block_size`."""
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # The remainder that does not fill a whole block is dropped.
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Ids of the FIM sentinel tokens.
    prefix_tok_id = tokenizer.convert_tokens_to_ids(data_args.fim_prefix_token)
    middle_tok_id = tokenizer.convert_tokens_to_ids(data_args.fim_middle_token)
    suffix_tok_id = tokenizer.convert_tokens_to_ids(data_args.fim_suffix_token)
    pad_tok_id = None

    # The pad token only exists in the tokenizer when `truncate_or_pad` is enabled.
    if data_args.truncate_or_pad:
        pad_tok_id = tokenizer.convert_tokens_to_ids(data_args.pad_token)

    def fim_transform(example):
        """
        This function performs FIM transformation on a single example (list of tokens)
        """
        if np_rng.binomial(1, data_args.fim_rate):
            # Two random cut points split the example into prefix / middle / suffix.
            boundaries = sorted(np_rng.randint(low=0, high=len(example) + 1, size=2))

            prefix = example[: boundaries[0]]
            middle = example[boundaries[0] : boundaries[1]]
            suffix = example[boundaries[1] :]

            if data_args.truncate_or_pad:
                # The three sentinel tokens make the example longer; trim the suffix to compensate.
                # The three slices partition `example`, so `diff` is always 3 and the padding
                # branch is kept only as a safeguard.
                total_length = len(prefix) + len(middle) + len(suffix) + 3
                diff = total_length - len(example)
                if diff > 0:
                    suffix = suffix[: max(0, len(suffix) - diff)]
                elif diff < 0:
                    suffix.extend([pad_tok_id] * (-diff))

            if np_rng.binomial(1, data_args.fim_spm_rate):
                # Suffix-Prefix-Middle (SPM) ordering.
                transformed_example = [prefix_tok_id, suffix_tok_id] + suffix + [middle_tok_id] + prefix + middle
            else:
                # Prefix-Suffix-Middle (PSM) ordering.
                transformed_example = [prefix_tok_id] + prefix + [suffix_tok_id] + suffix + [middle_tok_id] + middle
        else:
            transformed_example = example

        return transformed_example

    def apply_fim(examples):
        """
        Apply FIM transformation to a batch of examples
        """
        fim_transform_ids = [fim_transform(ids) for ids in examples["input_ids"]]
        examples["input_ids"] = fim_transform_ids
        examples["labels"] = fim_transform_ids
        # The transformation can change a sequence's length, so rebuild the all-ones attention mask
        # to match the new `input_ids`.
        examples["attention_mask"] = [[1] * len(mask) for mask in examples["input_ids"]]
        return examples

    # FIM is applied to each tokenized example first; the results are then concatenated and chunked.
    with training_args.main_process_first(desc="processing texts together"):
        if not data_args.streaming:
            fim_datasets = tokenized_datasets.map(
                apply_fim,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Performing FIM transformation",
            )
            lm_datasets = fim_datasets.map(
                group_texts,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
                desc=f"Grouping texts in chunks of {block_size}",
            )
        else:
            fim_datasets = tokenized_datasets.map(
                apply_fim,
                batched=True,
            )
            lm_datasets = fim_datasets.map(
                group_texts,
                batched=True,
            )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

        def preprocess_logits_for_metrics(logits, labels):
            """Reduce the logits to predicted token ids before they are accumulated for metrics."""
            if isinstance(logits, tuple):
                # When the model returns a tuple, the logits are its first element.
                logits = logits[0]
            return logits.argmax(dim=-1)

        metric = evaluate.load("accuracy")

        def compute_metrics(eval_preds):
            """Token-level accuracy; predictions are shifted by one position against the labels."""
            preds, labels = eval_preds
            # The prediction at position i is compared with the label at position i + 1.
            labels = labels[:, 1:].reshape(-1)
            preds = preds[:, :-1].reshape(-1)
            return metric.compute(predictions=preds, references=labels)

    # Build the Trainer; the metric hooks are only installed when evaluating off-TPU.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        processing_class=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics
        if training_args.do_eval and not is_torch_xla_available(check_is_tpu=True)
        else None,
        preprocess_logits_for_metrics=(
            preprocess_logits_for_metrics
            if training_args.do_eval and not is_torch_xla_available(check_is_tpu=True)
            else None
        ),
    )

    # Training.
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()

        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation.
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()

        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            # The loss is too large to exponentiate.
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Metadata for the model card / hub upload.
    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)
|
|
|
|
def _mp_fn(index):
    """
    Run `main()`; the `index` argument is accepted but never used.

    NOTE(review): presumably the per-process entry point for a multi-process launcher that passes a
    process index -- confirm against whatever spawns this script.
    """
    main()
|
|
|
|
# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
|