"""BERT-based text encoder (PyTorch Lightning) and embedding helpers."""

import pytorch_lightning as pl
import torch
from transformers import AutoModel, AutoTokenizer, BertModel, BertTokenizer, DistilBertModel

from src.utils import add_emoji_tokens, add_new_line_token, vectorise_dict
from config import DEVICE

torch.set_default_dtype(torch.float32)


class EncoderPL(pl.LightningModule):
    """Wraps a (Distil)BERT encoder behind a LightningModule.

    With ``cls=True`` the forward pass returns the pooled [CLS] embedding;
    otherwise it returns the attention-mask-weighted sum of the last hidden
    states, shaped (batch, 1, hidden).
    """

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        tokenizer: AutoTokenizer | None = None,
        bert: AutoModel | None = None,
        cls: bool = False,
        layer_dict: dict | None = None,
        device=DEVICE,
    ):
        super().__init__()

        # Device that tokenized inputs are moved to in forward().
        self._device = device
        self.cls = cls
        self.model_name = model_name
        # Accepted so callers (e.g. get_bert_embedding) can pass layer
        # configuration through; it is stored but not applied here.
        self.layer_dict = layer_dict or {}

        self.tokenizer = tokenizer if tokenizer is not None else BertTokenizer.from_pretrained(model_name)
        self.bert = bert if bert is not None else BertModel.from_pretrained(model_name)

        # Only extend the vocabulary when we built the tokenizer ourselves;
        # a caller-supplied tokenizer is assumed to be fully configured.
        if tokenizer is None:
            self.tokenizer = add_emoji_tokens(self.tokenizer)
            self.tokenizer = add_new_line_token(self.tokenizer)
            self.bert.resize_token_embeddings(len(self.tokenizer))

        # Record the working dtype in the model config.
        self.bert.config.torch_dtype = "float32"

    def forward(self, text: str | list[str]) -> torch.Tensor:
        # Pad to the model's maximum length so batch shapes are fixed.
        encoded = self.tokenizer(
            text, return_tensors="pt", padding="max_length", truncation=True
        ).to(self._device)

        # DistilBERT has no token-type embeddings, so drop the ids if present.
        if isinstance(self.bert, DistilBertModel):
            encoded.pop("token_type_ids", None)

        bert_output = self.bert(**encoded)

        if self.cls:
            if getattr(bert_output, "pooler_output", None) is not None:
                embedding = bert_output.pooler_output.unsqueeze(dim=1)
            else:
                # No pooler (e.g. DistilBERT): fall back to the raw hidden
                # state of the first ([CLS]) token.
                embedding = bert_output.last_hidden_state[:, 0, :].unsqueeze(dim=1)
        else:
            last_hidden_state = bert_output.last_hidden_state

            if last_hidden_state.dim() == 2:
                last_hidden_state = last_hidden_state.unsqueeze(dim=0)

            # Sum the hidden states of the non-padding tokens:
            # (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, 1, hidden).
            attention_mask = encoded["attention_mask"].to(last_hidden_state.dtype).unsqueeze(dim=1)
            embedding = torch.matmul(attention_mask, last_hidden_state)

        return embedding

    def configure_optimizers(self):
        # AdamW over the trainable parameters only.
        return torch.optim.AdamW(
            (p for p in self.parameters() if p.requires_grad), lr=1e-3
        )


def get_bert_embedding(
    text: str, as_list: bool = True, cls: bool = False, device=DEVICE, layer_dict: dict | None = None
) -> list | torch.Tensor:
    """Embed ``text`` with a freshly built EncoderPL (downloads weights on first use)."""
    encoder = EncoderPL(cls=cls, layer_dict=layer_dict).to(device)
    encoder.eval()
    with torch.no_grad():
        embedding = encoder(text)

    if as_list:
        # (1, 1, hidden) -> flat list of floats.
        embedding = embedding.tolist()[0][0]

    return embedding


def get_concat_embedding(
    text: str | None = None,
    bert_embedding: list | None = None,
    other_features: dict | None = None,
    cls: bool = False,
    device=DEVICE,
    layer_dict: dict | None = None,
) -> list:
    """Concatenate a BERT text embedding with vectorised hand-crafted features."""
    if not bert_embedding:
        if text is None:
            raise ValueError("Both text and bert_embedding are empty!")
        bert_embedding = get_bert_embedding(text=text, cls=cls, device=device, layer_dict=layer_dict)

    other_features = vectorise_dict(other_features or {}, as_list=True)

    return bert_embedding + other_features
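

# Minimal usage sketch (illustrative addition, not part of the library API):
# exercises EncoderPL and both helpers end to end. Assumes the
# "bert-base-uncased" weights can be downloaded and that the hypothetical
# feature name below matches whatever vectorise_dict expects.
if __name__ == "__main__":
    sample = "hello world\nnice day"

    # Reusing one encoder avoids rebuilding the model on every call.
    encoder = EncoderPL().to(DEVICE)
    encoder.eval()
    with torch.no_grad():
        print(encoder(sample).shape)  # (1, 1, 768) for bert-base-uncased

    # One-shot helpers built on top of EncoderPL.
    vec = get_bert_embedding(sample)
    combined = get_concat_embedding(
        text=sample, other_features={"n_words": 4}  # hypothetical feature
    )
    print(len(vec), len(combined))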