from torch.utils.data import DataLoader
import pytorch_lightning as pl
import wandb
from torch import nn
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning import Trainer

import pandas as pd
from loguru import logger
from dotenv import load_dotenv
import torch

from src.regression.datasets import DecoderDatasetTorch, regression_dataset
from src.regression.PL import DecoderPL

# Load environment variables (e.g. the W&B API key) from a local .env file.
load_dotenv()


def train_decoder_PL(
    train: pd.DataFrame,
    test: pd.DataFrame,
    artifact_path: str | None = None,
    resume: bool | str = "must",
    run_id: str | None = None,
    run_name: str = "sanity",
    model_class=DecoderPL,
    max_epochs: int = 2,
    layer_norm: bool = True,
    embedding_column: str = "my_full_mean_embedding",
    device: str = "mps",
    *args,
    **kwargs
):
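    """Train (or resume) a Lightning decoder regressor and log the run to Weights & Biases.

    Rows with a missing ``aov`` value are dropped from both dataframes. If
    ``run_name`` is "sanity", the run is forced to a fresh two-epoch pass over
    the first 16 rows of each split. If ``artifact_path`` is given, the model is
    loaded from that W&B checkpoint artifact; otherwise a new ``model_class`` is
    built, with extra ``*args`` / ``**kwargs`` forwarded to its constructor.
    """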

    torch.set_default_dtype(torch.float32)

    # Drop rows with a missing `aov` value before building the datasets.
    train = train[train.aov.notna()].reset_index(drop=True)
    test = test[test.aov.notna()].reset_index(drop=True)

    if run_name == "sanity":
        # Sanity mode: fresh W&B run, two epochs, first 16 rows of each split,
        # just to confirm the whole pipeline runs end to end.
        resume = False
        run_id = None
        max_epochs = 2
        train = train.iloc[:16]
        test = test.iloc[:16]

    # Use the file_system sharing strategy to avoid "too many open files"
    # errors with multi-worker DataLoaders.
    torch.multiprocessing.set_sharing_strategy("file_system")

    train_dataset = DecoderDatasetTorch(df=train, embedding_column=embedding_column)
    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8)

    test_dataset = DecoderDatasetTorch(df=test, embedding_column=embedding_column)
    test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=8)

    # W&B logger; `resume` / `run_id` allow continuing a previous run.
    wandb_logger = WandbLogger(
        project="transformers",
        entity="sanjin_juric_fot",
        log_model=True,
        reinit=True,
        resume=resume,
        id=run_id,
        name=run_name,
    )

    if artifact_path is not None:
        # Resume from a model checkpoint stored as a W&B artifact.
        artifact = wandb_logger.use_artifact(artifact_path)
        artifact_dir = artifact.download()
        litmodel = model_class.load_from_checkpoint(artifact_dir + "/model.ckpt").to(device)
        logger.debug("loaded model from checkpoint")
    else:
        # Fresh model; the input dimension is inferred from the embedding column.
        litmodel = model_class(
            *args,
            input_dim=len(train.at[0, embedding_column]),
            layer_norm=layer_norm,
            **kwargs,
        ).to(device)

    # Keep the best checkpoint by validation loss and log the learning rate each epoch.
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    trainer = Trainer(
        accelerator=str(device),
        devices=1,
        logger=wandb_logger,
        log_every_n_steps=2,
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback, lr_monitor],
    )

    logger.debug("training...")
    trainer.fit(
        model=litmodel,
        train_dataloaders=train_dataloader,
        val_dataloaders=test_dataloader,
    )
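

# Example invocation (an illustrative sketch only: the parquet paths and
# settings below are assumptions, not part of this module):
#
#     train_df = pd.read_parquet("data/train.parquet")
#     test_df = pd.read_parquet("data/test.parquet")
#     train_decoder_PL(train=train_df, test=test_df, run_name="sanity", device="mps")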