from torch.utils.data.dataset import Dataset
from transformers.tokenization_utils import PreTrainedTokenizer
from tqdm import tqdm
import json
from dataclasses import dataclass
import torch
from relogic.pretrainkit.datasets.utils import pad_and_tensorize_sequence
import random


class TaBARTDataset(Dataset):
  """
  This dataset is used for the pretraining task on generation-based or retrieval-based
  text-schema pair examples.
  The fields that are used are `question`, `table_info.header`, and `entities`.
  We already make sure that every entity in `entities` appears in `table_info.header`.
  """
  def __init__(self,
               tokenizer: PreTrainedTokenizer,
               file_path: str,
               col_token: str):
    self.examples = []
    total = 0
    valid = 0
    with open(file_path, encoding="utf-8") as f:
      for line in tqdm(f):
        total += 1
        example = json.loads(line)
        text = example["question"]
        schema = example["table_info"]["header"]
        # Linearize the example as: [CLS] question <col> col_1 <col> col_2 ... col_n [SEP]
        tokens = [tokenizer.cls_token] + tokenizer.tokenize(text, add_prefix_space=True) + [col_token]
        column_spans = []
        start_idx = len(tokens)
        for column in schema:
          column_tokens = tokenizer.tokenize(column.lower(), add_prefix_space=True)
          tokens.extend(column_tokens)
          # Record the (start, end) token span of each column header.
          column_spans.append((start_idx, start_idx + len(column_tokens)))
          tokens.append(col_token)
          start_idx += len(column_tokens) + 1

        # Replace the trailing column separator with the sentence separator.
        tokens[-1] = tokenizer.sep_token
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        entities = example["entities"]
        # Binary label per column: 1 if the column appears among the example's entities.
        column_labels = [0] * len(schema)
        for entity in entities:
          if entity != "limit" and entity != "*":
            column_labels[schema.index(entity)] = 1
        # Skip overly long sequences.
        if len(input_ids) > 600:
          continue
        self.examples.append({
          "input_ids": input_ids,
          "column_spans": column_spans,
          "column_labels": column_labels
        })
        valid += 1

    print("Total {} and Valid {}".format(total, valid))

  def __len__(self):
    return len(self.examples)

  def __getitem__(self, i):
    return self.examples[i]
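
# Example input line (an illustrative sketch, not taken from the actual data):
# each line of `file_path` is a JSON object with `question`, `table_info.header`,
# and `entities`, where every entity other than "limit" and "*" names a header column, e.g.
#
#   {"question": "how many singers do we have?",
#    "table_info": {"header": ["singer_id", "name", "age"]},
#    "entities": ["*"]}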


@dataclass
class DataCollatorForTaBART:
  """
  Data collator for TaBART pretraining.
  `task` selects the objective: "mlm", "col_pred", or "mlm+col_pred"
  (the latter samples one of the two objectives per batch).
  """
  tokenizer: PreTrainedTokenizer
  task: str
  mlm_probability: float = 0.35

  def __post_init__(self):
    self.label_bos_id = self.tokenizer.cls_token_id
    self.label_eos_id = self.tokenizer.sep_token_id

  def collate_batch(self, examples):
    input_ids_sequences = [example["input_ids"] for example in examples]
    padded_input_ids_tensor = pad_and_tensorize_sequence(input_ids_sequences,
                                                         padding_value=self.tokenizer.pad_token_id)
    if self.task == "mlm":
      # Mask the encoder inputs; the model reconstructs the original (unmasked) sequence,
      # so the padded original ids serve as labels and the token-level MLM labels are unused.
      inputs, _ = self.mask_tokens(padded_input_ids_tensor.clone())
      return {
        "task": "mlm",
        "input_ids": inputs,
        "labels": padded_input_ids_tensor,
        "pad_token_id": self.tokenizer.pad_token_id,
        "label_bos_id": self.tokenizer.bos_token_id,
        "label_eos_id": self.tokenizer.eos_token_id,
        "label_padding_id": self.tokenizer.pad_token_id}
    elif self.task == "col_pred":
      column_labels_sequences = [example["column_labels"] for example in examples]
      padded_label_ids_tensor = pad_and_tensorize_sequence(column_labels_sequences,
                                                           padding_value=-100)
      column_spans_sequences = [example["column_spans"] for example in examples]
      padded_column_spans_tensor = pad_and_tensorize_sequence(column_spans_sequences,
                                                              padding_value=(0, 1))
      return {
        "task": "col_pred",
        "input_ids": padded_input_ids_tensor,
        "column_spans": padded_column_spans_tensor,
        "labels": padded_label_ids_tensor,
        "pad_token_id": self.tokenizer.pad_token_id}
    elif self.task == "mlm+col_pred":
      # Randomly alternate between the two objectives, favoring MLM.
      if random.random() < 0.6:
        inputs, _ = self.mask_tokens(padded_input_ids_tensor.clone())
        return {
          "task": "mlm",
          "input_ids": inputs,
          "labels": padded_input_ids_tensor,
          "pad_token_id": self.tokenizer.pad_token_id,
          "label_bos_id": self.tokenizer.bos_token_id,
          "label_eos_id": self.tokenizer.eos_token_id,
          "label_padding_id": self.tokenizer.pad_token_id}
      else:
        column_labels_sequences = [example["column_labels"] for example in examples]
        padded_label_ids_tensor = pad_and_tensorize_sequence(column_labels_sequences,
                                                             padding_value=-100)
        column_spans_sequences = [example["column_spans"] for example in examples]
        padded_column_spans_tensor = pad_and_tensorize_sequence(column_spans_sequences,
                                                                padding_value=(0, 1))
        return {
          "task": "col_pred",
          "input_ids": padded_input_ids_tensor,
          "column_spans": padded_column_spans_tensor,
          "labels": padded_label_ids_tensor,
          "pad_token_id": self.tokenizer.pad_token_id}

  def mask_tokens(self, inputs):
    """
    Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
    """
    if self.tokenizer.mask_token is None:
      raise ValueError(
        "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
      )

    labels = inputs.clone()
    # Sample masked positions with probability `mlm_probability`, excluding special and padding tokens.
    probability_matrix = torch.full(labels.shape, self.mlm_probability)
    special_tokens_mask = [
      self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    if self.tokenizer._pad_token is not None:
      padding_mask = labels.eq(self.tokenizer.pad_token_id)
      probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    # Only compute loss on masked positions.
    labels[~masked_indices] = -100

    # 80% of the time, replace masked input tokens with the mask token.
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

    # 10% of the time, replace masked input tokens with a random word
    # (half of the remaining 20%); the final 10% are left unchanged.
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    return inputs, labels
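

if __name__ == "__main__":
  # A minimal end-to-end sketch of wiring the dataset and collator into a DataLoader.
  # The checkpoint name, file path, and "<col>" special token below are assumptions
  # for illustration, not values prescribed by this module.
  from torch.utils.data import DataLoader
  from transformers import BartTokenizer

  tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
  tokenizer.add_special_tokens({"additional_special_tokens": ["<col>"]})

  dataset = TaBARTDataset(tokenizer, file_path="data/train.jsonl", col_token="<col>")

  # "mlm+col_pred" picks one objective per batch (MLM with probability 0.6);
  # "mlm" or "col_pred" runs a single objective for every batch.
  collator = DataCollatorForTaBART(tokenizer=tokenizer, task="mlm+col_pred")
  loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collator.collate_batch)

  batch = next(iter(loader))
  print(batch["task"], batch["input_ids"].shape)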