| | |
| | |
| | |
| |
|
| | import logging |
| | import os |
| |
|
| | import contextlib |
| | from typing import Optional |
| |
|
| | import numpy as np |
| | from unicore.data import ( |
| | Dictionary, |
| | MaskTokensDataset, |
| | NestedDictionaryDataset, |
| | NumelDataset, |
| | NumSamplesDataset, |
| | LMDBDataset, |
| | PrependTokenDataset, |
| | RightPadDataset, |
| | SortDataset, |
| | BertTokenizeDataset, |
| | data_utils, |
| | ) |
| | from unicore.tasks import UnicoreTask, register_task |
| |
|
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
@register_task("bert")
class BertTask(UnicoreTask):
    """Masked-language-model pretraining task (BERT-style).

    Reads an LMDB-backed split, tokenizes it, applies random token
    masking, and exposes right-padded, shuffle-sorted batches whose
    targets are the original (unmasked) tokens.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="colon separated path to data directories list, \
                will be iterated upon during epochs in round-robin manner",
        )
        parser.add_argument(
            "--mask-prob",
            type=float,
            default=0.15,
            help="probability of replacing a token with mask",
        )
        parser.add_argument(
            "--leave-unmasked-prob",
            type=float,
            default=0.1,
            help="probability that a masked token is unmasked",
        )
        parser.add_argument(
            "--random-token-prob",
            type=float,
            default=0.1,
            help="probability of replacing a token with a random token",
        )

    def __init__(self, args, dictionary):
        """Store the vocabulary and register the mask symbol.

        Args:
            args: parsed command-line namespace (must carry ``seed``).
            dictionary: token vocabulary used for masking and padding.
        """
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed
        # Register [MASK] as a special symbol; its index drives masking below.
        self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True)

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Instantiate the task, loading the vocabulary from ``args.data``."""
        dictionary = Dictionary.load(os.path.join(args.data, "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))
        return cls(args, dictionary)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        data_dir = self.args.data
        raw = LMDBDataset(os.path.join(data_dir, split + ".lmdb"))
        # NOTE(review): relies on args.max_seq_len, which is not registered
        # in add_args here — presumably contributed by the model/global
        # argument parser; confirm against the training entry point.
        tokenized = BertTokenizeDataset(
            raw,
            os.path.join(data_dir, "dict.txt"),
            max_seq_len=self.args.max_seq_len,
        )

        # Randomly mask tokens: returns (masked inputs, unmasked targets).
        masked_tokens, target_tokens = MaskTokensDataset.apply_mask(
            tokenized,
            self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
        )

        # Deterministic shuffle order derived from the global seed.
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(masked_tokens))

        pad = self.dictionary.pad()
        nested = NestedDictionaryDataset(
            {
                "net_input": {
                    "src_tokens": RightPadDataset(masked_tokens, pad_idx=pad),
                },
                "target": RightPadDataset(target_tokens, pad_idx=pad),
            },
        )
        self.datasets[split] = SortDataset(nested, sort_order=[shuffle])

    def build_model(self, args):
        """Build and return the model via the global model registry."""
        # Local import mirrors the original: avoids a circular import at
        # module load time.
        from unicore import models
        return models.build_model(args, self)
| |
|