| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| | from transformers import AutoModel |
| | from transformers import PreTrainedModel |
| |
|
| | from .configuration_leaf import LeafConfig |
| | from .mappings import idx_to_ef, idx_to_classname |
| |
|
| |
|
class LeafModel(PreTrainedModel):
    """
    LEAF model for text classification.

    Wraps a pretrained encoder (loaded via ``AutoModel`` from ``config.model_name``),
    averages its token representations over the positions selected by the attention
    mask, and feeds the pooled vector to a ``ClassificationHead``.
    """
    config_class = LeafConfig

    def __init__(self, config: LeafConfig):
        """
        Build the encoder and the classification head.

        Args:
            config: Model configuration; ``config.model_name`` names the pretrained
                encoder that is loaded.
        """
        super().__init__(config)
        self._base_model = AutoModel.from_pretrained(config.model_name)
        # Device string handed to the head so its lookup tensor is placed there.
        self._device = "cuda" if torch.cuda.is_available() else "cpu"

        # The head maps the encoder's hidden size onto a fixed set of 2097 classes.
        self.head = ClassificationHead(
            hidden_dim=self._base_model.config.hidden_size,
            num_classes=2097,
            idx_to_ef=idx_to_ef,
            idx_to_classname=idx_to_classname,
            device=self._device,
        )

    def forward(self, input_ids, attention_mask, **kwargs) -> dict:
        """
        Encode the inputs, mean-pool over unmasked positions and classify.

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Mask passed to the encoder and reused as pooling weights
                (assumed to be (batch, seq) with 1 for real tokens - TODO confirm).
            **kwargs: Forwarded to the head; ``classes`` is set to ``None`` when the
                caller did not supply it.

        Returns:
            The dictionary produced by the classification head.
        """
        # The head expects a "classes" entry, so make sure one is always present.
        kwargs.setdefault("classes", None)

        token_states = self._base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        # Add a trailing axis so the mask broadcasts across the hidden dimension.
        mask = attention_mask.unsqueeze(-1)
        summed = (token_states * mask.type_as(token_states)).sum(dim=1)

        # Rows whose mask is entirely zero would divide by zero; use 1 instead.
        counts = mask.sum(dim=1)
        counts = counts.masked_fill(counts == 0, 1)

        return self.head(summed / counts, **kwargs)
| |
|
| |
|
class ClassificationHead(nn.Module):
    """
    Model head to predict a categorical target variable.

    Applies a single linear layer to pooled activations and returns the logits,
    the predicted class index, the score looked up for that class and, when a
    name table is available, the predicted class name. A cross-entropy loss is
    added when target classes are supplied.
    """

    def __init__(self, hidden_dim: int, num_classes: int, idx_to_ef: dict, idx_to_classname: Optional[dict],
                 device: str):
        """
        Args:
            hidden_dim: Size of the incoming activation vectors.
            num_classes: Number of output classes (width of the logits).
            idx_to_ef: Mapping from class index to the value reported as
                ``ef_score``; keys may be ints or numeric strings.
            idx_to_classname: Optional mapping from ``str(class index)`` to a class
                name; when empty or ``None`` no ``"class"`` entry is returned.
            device: Device on which the score lookup tensor is placed.
        """
        super().__init__()
        self.linear = nn.Linear(in_features=hidden_dim, out_features=num_classes)
        self.loss = nn.CrossEntropyLoss()

        # Position i of the lookup tensor must hold the score of class i, because the
        # tensor is indexed with predicted class indices in forward().
        # NOTE(review): this is a plain attribute, so it does not follow the module
        # when it is moved with .to(); it stays on `device`.
        ordered_keys = self._ordered_class_keys(idx_to_ef)
        self.idx_to_ef = torch.Tensor([idx_to_ef[k] for k in ordered_keys]).to(device)
        self.idx_to_classname = idx_to_classname

    @staticmethod
    def _ordered_class_keys(mapping: dict) -> list:
        """Return the mapping's keys ordered by class index.

        String keys are compared by their integer value so that "10" sorts after
        "2"; plain ``sorted`` would order them lexicographically and misalign the
        lookup tensor. Keys that are not numeric strings fall back to plain sorting.
        """
        keys = list(mapping)
        try:
            return sorted(keys, key=lambda k: int(k) if isinstance(k, str) else k)
        except ValueError:
            return sorted(keys)

    def forward(self, activations: torch.Tensor, classes: Optional[torch.Tensor] = None, **kwargs) -> dict:
        """
        Classify pooled activations.

        Defined as ``forward`` (rather than overriding ``__call__``) so that the
        usual ``nn.Module`` call machinery, including hooks, stays active; calling
        the head as ``head(activations, classes=...)`` works as before.

        Args:
            activations: Input to the linear layer (assumed (batch, hidden_dim) -
                TODO confirm against caller).
            classes: Optional target class indices; when given, the cross-entropy
                loss is computed and returned under ``"loss"``.
            **kwargs: Accepted and ignored so callers can pass extra keyword
                arguments through.

        Returns:
            Dict with ``"logits"``, optionally ``"loss"``, ``"class_idx"``,
            ``"ef_score"`` and, if a class-name table is set, ``"class"``.
        """
        logits = self.linear(activations)
        return_dict = {"logits": logits}

        # Compare against None explicitly: truth-testing a tensor raises for more
        # than one element and is False for a single label equal to 0.
        if classes is not None:
            return_dict["loss"] = self.loss(logits, classes)

        # Softmax is monotonic, so the arg-max of the logits is the predicted class.
        predicted_classes = torch.argmax(logits, dim=1)
        return_dict["class_idx"] = predicted_classes
        return_dict["ef_score"] = self.idx_to_ef[predicted_classes]

        if self.idx_to_classname:
            # The name table is keyed by the string form of the class index.
            return_dict["class"] = [self.idx_to_classname[str(c)] for c in predicted_classes.tolist()]
        return return_dict
| |
|