| import emoji |
| import numpy as np |
| import pytorch_lightning as pl |
| import torch |
| import torch.nn.functional as F |
| from loguru import logger |
| from torch import nn |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
| from torchmetrics import R2Score |
| from transformers import BertModel, BertTokenizerFast |
|
|
| from src.utils import get_sentiment, vectorise_dict |
| from src.utils.neural_networks import set_layer |
| from config import DEVICE |
|
|
| from .DecoderPL import DecoderPL |
| from .EncoderPL import EncoderPL |
|
|
|
|
# Process-wide setting applied when this module is imported: floating-point
# tensors created without an explicit dtype (e.g. via torch.Tensor(...))
# default to float32 for all code in the process, not only this module.
torch.set_default_dtype(torch.float32)
|
|
|
|
class FullModelPL(pl.LightningModule):
    """Predict a batch's ``"ctr"`` target from its text plus non-text features.

    ``EncoderPL`` embeds the batch text. Per-text sentiment scores and the
    remaining tensor-valued batch entries are concatenated onto that embedding
    and passed through ``DecoderPL``. Training minimises MSE against
    ``batch["ctr"]``. Parameters whose name contains ``"bert"`` are frozen.
    """

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        nontext_features: list[str] | None = None,
        encoder: EncoderPL | None = None,
        decoder: DecoderPL | None = None,
        layer_norm: bool = True,
        device=DEVICE,
        T_max: int = 10,
    ):
        """Build (or adopt) the encoder/decoder and the optimizer/scheduler.

        Args:
            model_name: Model name forwarded to ``EncoderPL`` when no
                ``encoder`` is supplied.
            nontext_features: Names of the non-text features. Only its length
                is used, to size the default decoder input. Defaults to
                ``["aov"]``.
            encoder: Optional pre-built encoder; otherwise one is created.
            decoder: Optional pre-built decoder; otherwise one is created.
            layer_norm: Forwarded to the default ``DecoderPL``.
            device: Forwarded to the default ``EncoderPL`` / ``DecoderPL``.
            T_max: Passed to ``CosineAnnealingLR`` as its ``T_max``.
        """
        super().__init__()
        # None sentinel instead of a mutable default list shared across calls.
        if nontext_features is None:
            nontext_features = ["aov"]

        # NOTE(review): ``self.device`` is this LightningModule's own device
        # (CPU at construction under Lightning's defaults), not the ``device``
        # argument, so these .to() calls may move sub-models built on
        # ``device`` back to CPU until the Trainer relocates the whole module.
        # Confirm this is intended.
        self.encoder = (
            encoder.to(self.device)
            if encoder is not None
            else EncoderPL(model_name=model_name, device=device).to(self.device)
        )
        self.decoder = (
            decoder.to(self.device)
            if decoder is not None
            else DecoderPL(
                # presumably 768 = BERT-base hidden size and 5 = number of
                # sentiment scores per text — TODO confirm against get_sentiment
                input_dim=768 + len(nontext_features) + 5,
                layer_norm=layer_norm,
                device=device,
            ).to(self.device)
        )

        self.MSE = nn.MSELoss()
        self.R2 = R2Score()

        # Freeze BERT *before* collecting parameters, so the requires_grad
        # filter really excludes the frozen weights from the optimizer.
        self._freeze_bert()
        self.optimizer = torch.optim.AdamW(
            [p for p in self.parameters() if p.requires_grad], lr=3 * 1e-4
        )
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=T_max)

    def _freeze_bert(self) -> None:
        """Set ``requires_grad = False`` on every parameter whose name contains "bert"."""
        for name, param in self.named_parameters():
            if "bert" in name:
                param.requires_grad = False

    def forward(self, input_dict: dict):
        """Run encoder + decoder on one batch.

        Args:
            input_dict: Batch containing a ``"text"`` entry (iterated as
                individual texts) plus tensor-valued non-text features. An
                optional ``"ctr"`` target is ignored. The caller's dict is
                not mutated.

        Returns:
            The output of ``self.decoder.forward`` for the concatenated
            [sentence embedding, non-text features] tensor.
        """
        # Shallow copy so popping keys leaves the caller's batch intact.
        input_dict = input_dict.copy()
        text = input_dict.pop("text")
        logger.debug("text: {}", text)

        # The target must not leak into the feature vector.
        input_dict.pop("ctr", None)

        sentence_embedding = self.encoder.forward(text=text)

        sentiment = get_sentiment_for_list_of_texts(text)
        input_dict = input_dict | sentiment
        input_dict = {k: v.to(self.device) for k, v in input_dict.items()}

        # Assuming vectorise_dict returns one 1-D (batch,) tensor per feature,
        # stack -> (n_features, batch), .T -> (batch, n_features) and
        # unsqueeze(1) -> (batch, 1, n_features) — TODO confirm.
        nontext_vec = vectorise_dict(input_dict)
        nontext_tensor = torch.stack(nontext_vec).T.unsqueeze(1).to(torch.float32)

        logger.debug(
            "embedding device: {}, nontext device: {}",
            sentence_embedding.get_device(),
            nontext_tensor.get_device(),
        )
        x = torch.cat((sentence_embedding, nontext_tensor), 2)
        logger.debug("decoder device: {}, input device: {}", self.decoder.device, x.get_device())

        return self.decoder.forward(x)

    def training_step(self, batch):
        """Compute the MSE loss for one batch and log it as ``train_loss``.

        ``pred`` is returned detached so that ``training_epoch_end`` can
        aggregate predictions without keeping autograd graphs alive.
        """
        loss_and_metrics = self._get_loss(batch, get_metrics=True)
        loss = loss_and_metrics["loss"]
        pred = loss_and_metrics["pred"].detach()
        act = loss_and_metrics["act"]

        self.log("train_loss", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True)

        return {"loss": loss, "pred": pred, "act": act}

    def configure_optimizers(self):
        """Keep BERT frozen and return the optimizer/scheduler built in ``__init__``."""
        # Idempotent: re-freezes BERT if requires_grad was re-enabled after construction.
        self._freeze_bert()
        return dict(optimizer=self.optimizer, lr_scheduler=self.scheduler)

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        """Step ``scheduler``, passing ``metric`` to it when one is provided."""
        logger.debug(scheduler)
        if metric is None:
            scheduler.step()
        else:
            scheduler.step(metric)

    def validation_step(self, batch, batch_idx):
        """Compute the batch MSE loss and log it as ``val_loss``."""
        loss_and_metrics = self._get_loss(batch, get_metrics=True)
        loss = loss_and_metrics["loss"]
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)

    def training_epoch_end(self, training_step_outputs):
        """Log ``train_R2``, the R² of all training predictions in the epoch.

        Every batch is included. Predictions and targets are flattened and
        concatenated (not stacked), so a smaller final batch no longer has to
        be discarded. R² = 1 - mean((pred - act)^2) / var(act), using the
        population variance. It is logged as NaN when the targets have zero
        variance. Nothing is logged if no batch carried a target.
        """
        outputs = [out for out in training_step_outputs if out.get("act") is not None]
        if not outputs:
            return

        # Assumes each batch's pred and act hold the same number of elements
        # (one prediction per target) — TODO confirm against DecoderPL output.
        pred = torch.cat([out["pred"].detach().flatten() for out in outputs]).to(torch.float32)
        act = torch.cat([out["act"].detach().flatten() for out in outputs]).to(torch.float32)

        rss = float(torch.mean(torch.square(pred - act)))
        tss = float(torch.var(act, unbiased=False))
        r2 = 1 - rss / tss if tss > 0 else float("nan")

        self.log("train_R2", r2, prog_bar=True, logger=True)

    def _get_loss(self, batch, get_metrics: bool = False):
        """Run the model on ``batch`` and compute MSE when a ``"ctr"`` target exists.

        ``get_metrics`` is currently unused and kept for call compatibility.

        Returns:
            dict with ``"pred"`` (float32 model output), ``"act"`` (float32
            target on ``self.device``, or None) and ``"loss"`` (MSE, or None).
        """
        pred = self.forward(input_dict=batch).to(torch.float32)

        act, loss = None, None
        if "ctr" in batch:
            act = batch["ctr"].to(torch.float32).to(self.device)
            loss = self.MSE(pred, act).to(torch.float32)

        return {"loss": loss, "pred": pred, "act": act}
|
|
|
|
def get_sentiment_for_list_of_texts(texts: list[str]) -> dict:
    """Score each text with ``get_sentiment`` and collate the results per key.

    Args:
        texts: Texts to score, one ``get_sentiment`` call each.

    Returns:
        A dict mapping each key of the first text's sentiment dict to a 1-D
        tensor (in the current default floating dtype) holding that key's
        value for every text, in input order. An empty ``texts`` yields ``{}``.
    """
    per_text = [get_sentiment(text) for text in texts]
    # No texts means no keys to collate; avoids IndexError on per_text[0].
    if not per_text:
        return {}

    # Same dtype the legacy torch.Tensor(...) constructor would have produced.
    dtype = torch.get_default_dtype()
    return {
        key: torch.tensor([scores[key] for scores in per_text], dtype=dtype)
        for key in per_text[0]
    }
|
|