import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel, PretrainedConfig


class DistilBertClassifier(PreTrainedModel):
    """DistilBERT encoder with a small MLP head for sequence classification."""

    def __init__(self, bert_config: PretrainedConfig,
                 model_name='distilbert-base-uncased',
                 tokenizer_len=30528, freeze_bert=False):
        super().__init__(bert_config)
        # MLP hidden width and number of target classes for the head.
        hidden_dim, n_classes = 300, 91

        # Load the pretrained encoder and grow its embedding matrix to
        # match a tokenizer that had extra tokens added beyond the base vocab.
        self.bert = AutoModel.from_pretrained(model_name)
        self.bert.resize_token_embeddings(tokenizer_len)

        # Classification head applied to the [CLS] token representation.
        self.classifier = nn.Sequential(
            nn.GELU(),
            nn.Linear(self.bert.config.hidden_size, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.05),
            nn.Linear(hidden_dim, n_classes),
        )

        # Optionally freeze the encoder so only the classifier head is trained.
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)
        # DistilBERT has no pooled output, so take the hidden state of the
        # first ([CLS]) token as the sequence representation.
        last_hidden_state_cls = outputs[0][:, 0, :]
        logits = self.classifier(last_hidden_state_cls)
        return logits
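

# Minimal usage sketch: builds the model from the base DistilBERT config and
# runs one untrained forward pass on a toy batch. Assumes the
# 'distilbert-base-uncased' checkpoint is available; since the base tokenizer
# has fewer than tokenizer_len (30528) tokens, its ids stay in range after
# the embedding resize.
if __name__ == '__main__':
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained('distilbert-base-uncased')
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertClassifier(config)

    batch = tokenizer(['a short example sentence'],
                      padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        logits = model(batch['input_ids'], batch['attention_mask'])
    print(logits.shape)  # expected: torch.Size([1, 91])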