| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import logging |
| | from dataclasses import dataclass, field |
| | from typing import Optional |
| |
|
| | from seq2seq_trainer import arg_to_scheduler |
| |
|
| | from transformers import TrainingArguments |
| |
|
| |
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
| |
|
| |
|
@dataclass
class Seq2SeqTrainingArguments(TrainingArguments):
    """
    Extra training arguments for sequence-to-sequence runs, layered on top of :class:`TrainingArguments`.

    Parameters:
        label_smoothing (:obj:`float`, `optional`, defaults to 0):
            The label smoothing epsilon to apply (if not zero).
        sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to use SortishSampler or not. It sorts the inputs according to their lengths in order to
            minimize the padding size.
        predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to use generate to calculate generative metrics (ROUGE, BLEU).
        adafactor (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to use adafactor.
        encoder_layerdrop (:obj:`float`, `optional`, defaults to :obj:`None`):
            Encoder layer dropout probability. Goes into model.config.
        decoder_layerdrop (:obj:`float`, `optional`, defaults to :obj:`None`):
            Decoder layer dropout probability. Goes into model.config.
        dropout (:obj:`float`, `optional`, defaults to :obj:`None`):
            Dropout probability. Goes into model.config.
        attention_dropout (:obj:`float`, `optional`, defaults to :obj:`None`):
            Attention dropout probability. Goes into model.config.
        lr_scheduler (:obj:`str`, `optional`, defaults to :obj:`"linear"`):
            Which lr scheduler to use. The help text lists the sorted keys of ``arg_to_scheduler``.
    """

    # --- Loss, sampling and evaluation switches ---
    label_smoothing: Optional[float] = field(
        default=0.0,
        metadata={"help": "The label smoothing epsilon to apply (if not zero)."},
    )
    sortish_sampler: bool = field(
        default=False,
        metadata={"help": "Whether to SortishSampler or not."},
    )
    predict_with_generate: bool = field(
        default=False,
        metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."},
    )
    adafactor: bool = field(
        default=False,
        metadata={"help": "whether to use adafactor"},
    )

    # --- Dropout settings; per the help text these go into model.config ---
    encoder_layerdrop: Optional[float] = field(
        default=None,
        metadata={"help": "Encoder layer dropout probability. Goes into model.config."},
    )
    decoder_layerdrop: Optional[float] = field(
        default=None,
        metadata={"help": "Decoder layer dropout probability. Goes into model.config."},
    )
    dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Dropout probability. Goes into model.config."},
    )
    attention_dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Attention dropout probability. Goes into model.config."},
    )

    # --- Learning-rate schedule; the help text is built once, when the class body is evaluated ---
    lr_scheduler: Optional[str] = field(
        default="linear",
        metadata={"help": f"Which lr scheduler to use. Selected in {sorted(arg_to_scheduler.keys())}"},
    )
| |
|