from os.path import join

import torch
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader
from torch_scatter import scatter
from tqdm import tqdm

from visnet.datasets import *
from visnet.utils import MissingLabelException, make_splits


class DataModule(LightningDataModule):
    def __init__(self, hparams):
        super(DataModule, self).__init__()
        # Accept hyperparameters either as an argparse.Namespace or as a plain dict.
        if hasattr(hparams, "__dict__"):
            self.hparams.update(hparams.__dict__)
        else:
            self.hparams.update(hparams)
        self._mean, self._std = None, None
        self._saved_dataloaders = dict()
        self.dataset = None

    def prepare_dataset(self):
        assert hasattr(self, f"_prepare_{self.hparams['dataset']}_dataset"), \
            f"Dataset {self.hparams['dataset']} not defined"
        dataset_factory = lambda t: getattr(self, f"_prepare_{t}_dataset")()
        self.idx_train, self.idx_val, self.idx_test = dataset_factory(self.hparams["dataset"])

        print(f"train {len(self.idx_train)}, val {len(self.idx_val)}, test {len(self.idx_test)}")
        self.train_dataset = Subset(self.dataset, self.idx_train)
        self.val_dataset = Subset(self.dataset, self.idx_val)
        self.test_dataset = Subset(self.dataset, self.idx_test)

        if self.hparams["standardize"]:
            self._standardize()

    def train_dataloader(self):
        return self._get_dataloader(self.train_dataset, "train")

    def val_dataloader(self):
        loaders = [self._get_dataloader(self.val_dataset, "val")]
        # Additionally evaluate on the test set every `test_interval` epochs.
        delta = 1 if self.hparams["reload"] == 1 else 2
        if (
            len(self.test_dataset) > 0
            and (self.trainer.current_epoch + delta) % self.hparams["test_interval"] == 0
        ):
            loaders.append(self._get_dataloader(self.test_dataset, "test"))
        return loaders

    def test_dataloader(self):
        return self._get_dataloader(self.test_dataset, "test")

    @property
    def atomref(self):
        if hasattr(self.dataset, "get_atomref"):
            return self.dataset.get_atomref()
        return None

    @property
    def mean(self):
        return self._mean

    @property
    def std(self):
        return self._std

    def _get_dataloader(self, dataset, stage, store_dataloader=True):
        # Cache dataloaders unless the dataset is reloaded every epoch.
        store_dataloader = store_dataloader and not self.hparams["reload"]
        if stage in self._saved_dataloaders and store_dataloader:
            return self._saved_dataloaders[stage]

        if stage == "train":
            batch_size = self.hparams["batch_size"]
            shuffle = True
        elif stage in ["val", "test"]:
            batch_size = self.hparams["inference_batch_size"]
            shuffle = False

        dl = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.hparams["num_workers"],
            pin_memory=True,
        )

        if store_dataloader:
            self._saved_dataloaders[stage] = dl
        return dl

    @rank_zero_only
    def _standardize(self):
        def get_label(batch, atomref):
            if batch.y is None:
                raise MissingLabelException()

            if atomref is None:
                return batch.y.clone()

            # Subtract the per-molecule sum of atomic reference energies from the label.
            atomref_energy = scatter(atomref[batch.z], batch.batch, dim=0)
            return (batch.y.squeeze() - atomref_energy.squeeze()).clone()

        data = tqdm(
            self._get_dataloader(self.train_dataset, "val", store_dataloader=False),
            desc="computing mean and std",
        )
        try:
            atomref = self.atomref if self.hparams["prior_model"] == "Atomref" else None
            ys = torch.cat([get_label(batch, atomref) for batch in data])
        except MissingLabelException:
            rank_zero_warn(
                "Standardize is true but failed to compute dataset mean and "
                "standard deviation. Maybe the dataset only contains forces."
            )
            return None

        self._mean = ys.mean(dim=0)
        self._std = ys.std(dim=0)

    def _prepare_Chignolin_dataset(self):
        self.dataset = Chignolin(root=self.hparams["dataset_root"])
        train_size = self.hparams["train_size"]
        val_size = self.hparams["val_size"]

        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )

        return idx_train, idx_val, idx_test

    def _prepare_MD17_dataset(self):
        self.dataset = MD17(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
        train_size = self.hparams["train_size"]
        val_size = self.hparams["val_size"]

        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )

        return idx_train, idx_val, idx_test

    def _prepare_MD22_dataset(self):
        self.dataset = MD22(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
        # MD22 provides a fixed train+val size per molecule; use 95% for training and 5% for validation.
        train_val_size = self.dataset.molecule_splits[self.hparams["dataset_arg"]]
        train_size = round(train_val_size * 0.95)
        val_size = train_val_size - train_size

        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )

        return idx_train, idx_val, idx_test

    def _prepare_Molecule3D_dataset(self):
        self.dataset = Molecule3D(root=self.hparams["dataset_root"])
        split_dict = self.dataset.get_idx_split(self.hparams["split_mode"])
        idx_train = split_dict["train"]
        idx_val = split_dict["valid"]
        idx_test = split_dict["test"]

        return idx_train, idx_val, idx_test

    def _prepare_QM9_dataset(self):
        self.dataset = QM9(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
        train_size = self.hparams["train_size"]
        val_size = self.hparams["val_size"]

        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )

        return idx_train, idx_val, idx_test

    def _prepare_rMD17_dataset(self):
        self.dataset = rMD17(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])
        train_size = self.hparams["train_size"]
        val_size = self.hparams["val_size"]

        idx_train, idx_val, idx_test = make_splits(
            len(self.dataset),
            train_size,
            val_size,
            None,
            self.hparams["seed"],
            join(self.hparams["log_dir"], "splits.npz"),
            self.hparams["splits"],
        )

        return idx_train, idx_val, idx_test

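# Example usage (a minimal sketch, not part of the training pipeline): the keys
# below mirror the hyperparameters accessed in this module; the specific values
# are placeholders and would normally come from the experiment config.
#
#     hparams = {
#         "dataset": "QM9", "dataset_root": "data/qm9", "dataset_arg": "energy_U0",
#         "train_size": 110000, "val_size": 10000, "seed": 1,
#         "log_dir": "logs", "splits": None, "standardize": True,
#         "reload": 0, "test_interval": 10, "batch_size": 32,
#         "inference_batch_size": 64, "num_workers": 4, "prior_model": None,
#     }
#     dm = DataModule(hparams)
#     dm.prepare_dataset()
#     train_loader = dm.train_dataloader()
#     mean, std = dm.mean, dm.std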