# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import contextlib
from typing import Optional
import numpy as np
from unicore.data import (
Dictionary,
MaskTokensDataset,
NestedDictionaryDataset,
NumelDataset,
NumSamplesDataset,
LMDBDataset,
PrependTokenDataset,
RightPadDataset,
SortDataset,
BertTokenizeDataset,
data_utils,
)
from unicore.tasks import UnicoreTask, register_task
logger = logging.getLogger(__name__)
@register_task("bert")
class BertTask(UnicoreTask):
    """Pretraining task that trains masked language models (e.g., BERT)."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="colon separated path to data directories list, \
                will be iterated upon during epochs in round-robin manner",
        )
        parser.add_argument(
            "--mask-prob",
            default=0.15,
            type=float,
            help="probability of replacing a token with mask",
        )
        parser.add_argument(
            "--leave-unmasked-prob",
            default=0.1,
            type=float,
            help="probability that a masked token is unmasked",
        )
        parser.add_argument(
            "--random-token-prob",
            default=0.1,
            type=float,
            help="probability of replacing a token with a random token",
        )

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed
        # Register the mask token up front so masking below can reference its index.
        self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True)

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Build the task, loading the vocabulary from ``args.data``."""
        dictionary = Dictionary.load(os.path.join(args.data, "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))
        return cls(args, dictionary)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        lmdb_path = os.path.join(self.args.data, split + ".lmdb")
        vocab_path = os.path.join(self.args.data, "dict.txt")

        # Raw LMDB records -> token ids, truncated to the configured length.
        # NOTE(review): --max-seq-len is not registered in add_args above;
        # presumably it comes from the model/global arg parser — confirm.
        tokenized = BertTokenizeDataset(
            LMDBDataset(lmdb_path),
            vocab_path,
            max_seq_len=self.args.max_seq_len,
        )

        # BERT-style masking: yields (masked inputs, prediction targets).
        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            tokenized,
            self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
        )

        # Deterministic shuffle order, reproducible via the task seed.
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_dataset))

        net_input = {
            "src_tokens": RightPadDataset(
                src_dataset,
                pad_idx=self.dictionary.pad(),
            )
        }
        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    "net_input": net_input,
                    "target": RightPadDataset(
                        tgt_dataset,
                        pad_idx=self.dictionary.pad(),
                    ),
                },
            ),
            sort_order=[shuffle],
        )

    def build_model(self, args):
        """Delegate model construction to the unicore model registry."""
        from unicore import models

        return models.build_model(args, self)