File size: 2,174 Bytes
6d3b8ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d4693d
6d3b8ba
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import copy

import torch
from torch import nn
from torch.utils.data import Dataset
from transformers.models.distilbert.modeling_distilbert import Transformer

class CustomDataset(Dataset):
    """Map-style dataset of ``(token_ids, label)`` pairs for question/passage QA.

    Every row of ``df`` is encoded as a (question, passage) sentence pair,
    padded and truncated to exactly ``max_length`` token ids, and stored as
    one int64 tensor. Labels are stored as a float32 tensor.

    Args:
        df: frame with ``question``, ``passage`` and ``answer`` columns.
        tokenizer: Hugging Face-style tokenizer exposing ``encode``.
        truncation: truncation strategy forwarded to ``tokenizer.encode``.
            Because the inputs are passed as a pair, ``'only_first'`` and
            ``'only_second'`` work as well as the default
            ``'longest_first'``.
        max_length: length every encoded example is padded/truncated to.
    """

    def __init__(self, df, tokenizer, truncation='longest_first',
                 max_length=512):
        # Encode as a real sentence pair rather than splicing a literal
        # ' [SEP] ' into one string: the tokenizer then inserts its own
        # separator token, and pair-aware truncation strategies (e.g.
        # 'only_second') have a second sequence to act on. With a single
        # string they could not truncate, leaving ragged rows that
        # torch.tensor rejects.
        encoded = [
            tokenizer.encode(question,
                             text_pair=passage,
                             add_special_tokens=True,
                             padding='max_length',
                             truncation=truncation,
                             max_length=max_length)
            for question, passage in zip(df.question.values,
                                         df.passage.values)
        ]
        self.X = torch.tensor(encoded, dtype=torch.int64)

        # Labels as float32 regardless of the source column dtype.
        self.y = torch.tensor(df.answer.values, dtype=torch.float32)

    def __len__(self):
        """Return the number of encoded examples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(token_ids, label)`` pair at position ``idx``."""
        return self.X[idx], self.y[idx]


class QuestAnsweringDistilBERT(nn.Module):
    """Question-answering scorer: backbone + extra DistilBERT transformer head.

    The backbone's ``last_hidden_state`` is passed through a freshly built
    DistilBERT ``Transformer`` stack, flattened, and fed to a single linear
    unit, giving one output value per example.

    Args:
        bert: backbone module; called as ``bert(x, attention_mask=...)`` and
            expected to return an object with ``last_hidden_state``.
        config: DistilBERT-style config used to build the head. It is
            deep-copied before ``n_layers``/``hidden_dim``/``dim`` are
            overridden, so the caller's object is left unchanged.
        freeze_type: ``'all'`` freezes every backbone parameter; ``'emb'``
            and ``'part'`` (currently identical) freeze only
            ``bert.embeddings``; any other value freezes nothing.
        n_layers: number of transformer layers in the head.
        hidden_dim: feed-forward width of the head layers.
        dim: hidden size of the head (presumably must equal the backbone's
            output size — not checked here).
    """

    def __init__(self, bert, config, freeze_type, n_layers=2,
                 hidden_dim=3072, dim=768):
        super().__init__()

        self.bert = bert
        # Copy before overriding. The passed config is often bert.config;
        # mutating it in place would also change n_layers as seen by the
        # backbone itself (e.g. its head-mask length and any saved config).
        self.config = copy.deepcopy(config)
        self.config.n_layers = n_layers
        self.config.hidden_dim = hidden_dim
        self.config.dim = dim

        self.heads = Transformer(self.config)
        # The whole (seq_len, dim) head output is flattened into one linear
        # unit, so inputs must have exactly max_position_embeddings tokens.
        self.dense = nn.Linear(
            self.config.dim * self.config.max_position_embeddings, 1)

        if freeze_type == 'all':
            frozen = self.bert.parameters()
        elif freeze_type in ('emb', 'part'):
            frozen = self.bert.embeddings.parameters()
        else:
            frozen = ()
        for param in frozen:
            param.requires_grad = False

    def forward(self, x):
        """Return a flat tensor with one output value per row of ``x``.

        ``x`` holds token ids; id 0 is treated as padding. Its sequence
        length must equal ``config.max_position_embeddings`` for the final
        linear layer to accept the flattened head output.
        """
        # 1 for real tokens, 0 for padding (id 0); int64 on x's device.
        part_mask = (x != 0).long()

        x = self.bert(x, attention_mask=part_mask).last_hidden_state

        # All-ones head mask (no heads disabled); the head's Transformer
        # indexes it per layer, so it cannot be None. Built on x's device.
        head_mask = torch.ones((self.config.n_layers,), device=x.device)
        x = self.heads(x, attn_mask=part_mask, head_mask=head_mask)[0]
        x = self.dense(x.flatten(start_dim=1))
        return x.flatten()