import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class CustomConfig(PretrainedConfig):
    """Configuration for CustomModel; stores the number of target classes."""

    model_type = "roberta"

    def __init__(
        self,
        num_classes: int = 2,
        **kwargs,
    ):
        self.num_classes = num_classes
        super().__init__(**kwargs)


class MeanPooling(nn.Module):
    """Mean-pools token embeddings, ignoring padded positions."""

    def __init__(self):
        super().__init__()

    def forward(self, last_hidden_state, attention_mask):
        # Broadcast the attention mask over the hidden dimension so that
        # padded tokens contribute nothing to the sum.
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        )
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        # Clamp to avoid division by zero for all-padding sequences.
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings
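

# Example shapes: MeanPooling()(last_hidden_state, attention_mask) maps
# last_hidden_state [batch, seq_len, hidden] and attention_mask [batch, seq_len]
# to mean-pooled embeddings of shape [batch, hidden].
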
class CustomModel(PreTrainedModel):
    config_class = CustomConfig

    def __init__(
        self,
        cfg,
        num_labels=2,
        config_path=None,
        pretrained=True,
        binary_classification=True,
        **kwargs,
    ):
        self.cfg = cfg
        self.num_labels = num_labels

        # Build the backbone config from the hub, or restore one saved with torch.save.
        if config_path is None:
            self.config = AutoConfig.from_pretrained(
                self.cfg.model_name, output_hidden_states=True
            )
        else:
            self.config = torch.load(config_path)

        super().__init__(self.config)

        # Load pretrained backbone weights, or build the architecture from the config alone.
        if pretrained:
            self.model = AutoModel.from_pretrained(
                self.cfg.model_name, config=self.config
            )
        else:
            # AutoModel cannot be instantiated directly; from_config builds an
            # uninitialized backbone from the configuration.
            self.model = AutoModel.from_config(self.config)

        if self.cfg.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        self.pool = MeanPooling()
        self.binary_classification = binary_classification

        if self.binary_classification:
            # num_labels - 1 output units (a single unit in the default binary case),
            # squashed by a sigmoid in forward().
            self.fc = nn.Linear(self.config.hidden_size, self.num_labels - 1)
        else:
            self.fc = nn.Linear(self.config.hidden_size, self.num_labels)

        self._init_weights(self.fc)
        self.sigmoid_fn = nn.Sigmoid()

    def _init_weights(self, module):
        # Initialize the classification head with the same scheme the backbone
        # uses (normal init scaled by config.initializer_range).
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, input_ids, attention_mask, token_type_ids):
        # Run the backbone and mean-pool its last hidden states into one vector per example.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        last_hidden_states = outputs[0]
        feature = self.pool(last_hidden_states, attention_mask)
        return feature

    def forward(self, input_ids, attention_mask, token_type_ids):
        feature = self.feature(input_ids, attention_mask, token_type_ids)
        output = self.fc(feature)
        if self.binary_classification:
            # Binary head returns probabilities; the multi-class head returns raw logits.
            output = self.sigmoid_fn(output)
        return output
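

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: `cfg` only needs `model_name` and
# `gradient_checkpointing`, and the "roberta-base" checkpoint is available
# locally or downloadable; adjust both to your own training setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(model_name="roberta-base", gradient_checkpointing=False)
    model = CustomModel(cfg, num_labels=2, pretrained=True, binary_classification=True)
    model.eval()

    # Toy single-example batch; real inputs would come from the matching tokenizer.
    input_ids = torch.tensor([[0, 31414, 232, 2]])
    attention_mask = torch.ones_like(input_ids)
    token_type_ids = torch.zeros_like(input_ids)

    with torch.no_grad():
        probs = model(input_ids, attention_mask, token_type_ids)
    print(probs.shape)  # torch.Size([1, 1]): one sigmoid probability per example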