import logging

import torch.nn as nn
from transformers import LEDConfig, LEDModel, LEDPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput


class CustomLEDForResultsIdModel(LEDPreTrainedModel):
    """Per-token classification head on top of an LED encoder.

    Only the encoder half of an ``LEDModel`` is kept. Each token's final
    hidden state goes through dropout and a linear layer that produces
    ``config.num_labels`` scores.
    """

    def __init__(self, config: LEDConfig, checkpoint=None):
        """Build the encoder and the classification head.

        Args:
            config: LED configuration; ``num_labels`` and ``dropout`` are read
                here, and the encoder's ``d_model`` sizes the classifier input.
            checkpoint: Optional model name or path passed to
                ``LEDModel.from_pretrained``. When truthy, the encoder is taken
                from that pretrained model (loaded with ``config``); otherwise
                a freshly initialised encoder is built from ``config``.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        # Log instead of printing so constructing the model is silent unless
        # debug logging is enabled.
        logging.getLogger(__name__).debug(
            "LED token-classifier config: num_labels=%s, dropout=%s",
            config.num_labels,
            config.dropout,
        )

        # Only the encoder is kept; the rest of the full LEDModel is dropped.
        # NOTE: the attribute name ``led`` is kept as-is; renaming it would
        # change state-dict keys of saved checkpoints.
        if checkpoint:
            self.led = LEDModel.from_pretrained(checkpoint, config=config).get_encoder()
        else:
            self.led = LEDModel(config).get_encoder()

        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(self.led.config.d_model, self.num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        global_attention_mask=None,
        return_loss=True,
        output_attentions=None,
        output_hidden_states=None,
    ):
        """Encode the inputs and score every token.

        Args:
            input_ids: Token ids forwarded to the LED encoder.
            attention_mask: Padding mask forwarded to the LED encoder.
            labels: Optional per-token class ids. They are flattened to match
                logits flattened to ``(-1, num_labels)``; presumably shaped
                ``(batch, seq_len)`` — confirm against the data collator.
                Entries equal to -100 (PyTorch ``CrossEntropyLoss``'s default
                ``ignore_index``) do not contribute to the loss.
            global_attention_mask: Forwarded unchanged to the LED encoder.
            return_loss: Unused. Kept so the signature stays unchanged for
                existing callers.
            output_attentions: Forwarded to the encoder; the resulting
                attentions are returned in the output.
            output_hidden_states: Forwarded to the encoder; the resulting
                hidden states are returned in the output.

        Returns:
            ``TokenClassifierOutput`` with ``logits`` and, when ``labels`` is
            given, ``loss``. It supports both ``out.logits`` and
            ``out["logits"]``. Fields left as ``None`` (e.g. ``loss`` without
            labels) are omitted from its dict view.
        """
        outputs = self.led(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Request a ModelOutput explicitly so ``.last_hidden_state`` works
            # regardless of the config's return_dict setting.
            return_dict=True,
        )

        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # reshape (not view) tolerates non-contiguous label tensors;
            # moving labels to the logits' device avoids device mismatches.
            loss = loss_fct(
                logits.view(-1, self.num_labels),
                labels.to(logits.device).reshape(-1),
            )

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )