"""Hydra-driven training entry point: builds the PDB data module and the flow
matching model from the config and runs PyTorch Lightning training."""

import os

import GPUtil
import hydra
import torch
import wandb

from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.wandb import WandbLogger
from omegaconf import DictConfig, OmegaConf

from data.pdb_dataloader import PdbDataModule
from models.flow_module import FlowModule
from experiments import utils as eu

# Run W&B in offline mode; runs can be synced to the server later.
os.environ["WANDB_MODE"] = "offline"
log = eu.get_pylogger(__name__)
torch.set_float32_matmul_precision('high')

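# Example invocations (illustrative only; the script filename below is a placeholder).
# Config values come from ../configs/base.yaml and can be overridden with standard
# Hydra syntax, e.g.:
#
#   python train.py experiment.debug=True
#   python train.py experiment.warm_start=/path/to/last.ckpt experiment.num_devices=1
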
class Experiment:

    def __init__(self, *, cfg: DictConfig):
        self._cfg = cfg
        self._data_cfg = cfg.data
        self._exp_cfg = cfg.experiment
        self._datamodule: LightningDataModule = PdbDataModule(self._data_cfg)
        self._model: LightningModule = FlowModule(self._cfg)

    def train(self):
        callbacks = []
        if self._exp_cfg.debug:
            log.info("Debug mode.")
            logger = None
            self._exp_cfg.num_devices = 1
            self._data_cfg.loader.num_workers = 0
        else:
            logger = WandbLogger(
                **self._exp_cfg.wandb,
            )

            # Checkpoint directory.
            ckpt_dir = self._exp_cfg.checkpointer.dirpath
            os.makedirs(ckpt_dir, exist_ok=True)
            log.info(f"Checkpoints saved to {ckpt_dir}")

            # Save model checkpoints during training.
            callbacks.append(ModelCheckpoint(**self._exp_cfg.checkpointer))

            # Save the resolved config next to the checkpoints and mirror it
            # into the W&B run config.
            cfg_path = os.path.join(ckpt_dir, 'config.yaml')
            with open(cfg_path, 'w') as f:
                OmegaConf.save(config=self._cfg, f=f)
            cfg_dict = OmegaConf.to_container(self._cfg, resolve=True)
            flat_cfg = dict(eu.flatten_dict(cfg_dict))
            if isinstance(logger.experiment.config, wandb.sdk.wandb_config.Config):
                logger.experiment.config.update(flat_cfg)

        # Select available GPUs, preferring those with the most free memory.
        devices = GPUtil.getAvailable(
            order='memory', limit=8)[:self._exp_cfg.num_devices]
        log.info(f"Using devices: {devices}")
        trainer = Trainer(
            **self._exp_cfg.trainer,
            callbacks=callbacks,
            logger=logger,
            # Lightning will not inject a DistributedSampler; the datamodule
            # handles its own sampling.
            use_distributed_sampler=False,
            enable_progress_bar=True,
            enable_model_summary=True,
            devices=devices,
        )

        # Warm start: initialize weights from an existing checkpoint without
        # restoring optimizer or trainer state.
        if self._exp_cfg.warm_start is not None:
            self._model = FlowModule.load_from_checkpoint(
                self._exp_cfg.warm_start, strict=False, map_location="cpu")

        trainer.fit(
            model=self._model,
            datamodule=self._datamodule,
        )


@hydra.main(version_base=None, config_path="../configs", config_name="base.yaml")
def main(cfg: DictConfig):

    if cfg.experiment.warm_start is not None and cfg.experiment.warm_start_cfg_override:
        # Reuse the model config that was saved next to the warm-start checkpoint
        # so the architecture matches the loaded weights.
        warm_start_cfg_path = os.path.join(
            os.path.dirname(cfg.experiment.warm_start), 'config.yaml')
        warm_start_cfg = OmegaConf.load(warm_start_cfg_path)

        # Disable struct mode so the two model configs can be merged even if
        # they do not define exactly the same keys.
        OmegaConf.set_struct(cfg.model, False)
        OmegaConf.set_struct(warm_start_cfg.model, False)
        cfg.model = OmegaConf.merge(cfg.model, warm_start_cfg.model)
        OmegaConf.set_struct(cfg.model, True)
        log.info(f'Loaded warm start config from {warm_start_cfg_path}')

    exp = Experiment(cfg=cfg)
    exp.train()


if __name__ == "__main__":
    main()