class EntityBertNet(nn.Module):
    """BERT encoder -> max pooling over entity token positions -> linear classifier.

    The model returns raw logits from ``self.fc``; the softmax is left to the
    loss function.
    """

    def __init__(self):
        super().__init__()
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        # NOTE(review): HIDDEN_OUTPUT_FEATURES presumably must equal the BERT
        # hidden size (config.hidden_size) -- confirm where it is defined.
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, entity_indices):
        """Classify inputs from the BERT token vectors found at entity positions.

        Args:
            input_ids: token ids passed to the BERT model.
            attn_mask: attention mask passed to the BERT model.
            entity_indices: sequence positions of entity tokens; see
                ``pooled_output`` for the accepted shapes.

        Returns:
            Unnormalised logits of size ``NUM_CLASSES`` per example.
        """
        # With return_dict=False the model returns a plain tuple whose first
        # element is the per-token sequence output. Index it rather than
        # unpacking exactly two values, because the tuple grows when the config
        # also requests hidden states / attentions.
        bert_output = self.bert_base(
            input_ids=input_ids, attention_mask=attn_mask, return_dict=False
        )[0]
        # Max pooling at entity locations.
        entity_pooled_output = EntityBertNet.pooled_output(bert_output, entity_indices)
        # Logits only: softmax activation is done in the loss function.
        return self.fc(entity_pooled_output)

    @staticmethod
    def pooled_output(bert_output, indices):
        """Element-wise max over the token vectors at the given positions.

        Args:
            bert_output: tensor whose dim 1 is the sequence axis, e.g.
                ``(batch, seq_len, hidden)``.
            indices: integer tensor of positions along dim 1. It may be either
                one rank lower than ``bert_output`` (e.g. ``(batch, k)``), in
                which case it is broadcast across the last dimension, or already
                expanded to the same rank (e.g. ``(batch, k, hidden)``).

        Returns:
            Tensor with dim 1 reduced away, e.g. ``(batch, hidden)``.
        """
        if indices.dim() == bert_output.dim() - 1:
            # torch.gather needs an index of the same rank as its input, so
            # repeat each position across the feature dimension.
            indices = indices.unsqueeze(-1).expand(*indices.shape, bert_output.size(-1))
        # gather requires int64 indices; .long() is a no-op if already int64.
        outputs = torch.gather(input=bert_output, dim=1, index=indices.long())
        pooled_output, _ = torch.max(outputs, dim=1)
        return pooled_output