| import torch |
| import pytorch_lightning as pl |
| from pytorch_lightning.loggers import WandbLogger |
| from pytorch_lightning.callbacks import ModelCheckpoint |
| import os |
| import sys |
| import argparse |
| import hydra |
| from omegaconf import OmegaConf |
| from datetime import datetime |
| |
# Absolute path of the directory that contains this script; it is also
# passed to Hydra as the config search path further down in this file.
CONFIG_DIR = os.path.dirname(os.path.abspath(__file__))

# Put the parent of that directory at the front of the import path so the
# imports that follow can be resolved from there.
_PARENT_DIR = os.path.dirname(CONFIG_DIR)
sys.path.insert(0, _PARENT_DIR)
|
|
| import wandb |
| from lightning_modules import AnyOrderInsertionFlowModule |
|
|
|
|
| torch.set_printoptions(threshold=10_000) |
| torch.set_float32_matmul_precision("high") |
|
|
| |
| torch._dynamo.config.optimize_ddp = False |
|
|
def _init_wandb_logger(config):
    """Start a Weights & Biases run and return a Lightning logger for it.

    Reads ``config.wandb.project``, ``config.wandb.name`` and
    ``config.wandb.path``; the fully resolved config is attached to the run.
    """
    wandb.init(
        project=config.wandb.project,
        name=config.wandb.name,
        config=OmegaConf.to_container(config, resolve=True),
        dir=config.wandb.path,
    )
    return WandbLogger(
        project=wandb.run.project,
        name=wandb.run.name,
        log_model=False,
    )


def _build_data_module(config):
    """Build the data module from the HuggingFace or the local dataset.

    The HuggingFace path is taken when the config has an ``hf_dataset``
    section; otherwise the local dataset loader is used. Both setup
    functions receive ``config`` itself (their names suggest they also
    update it - see their definitions).
    """
    if hasattr(config, 'hf_dataset'):
        # Imported only when this branch is taken.
        from mol_dataset import setup_hf_data_and_update_config
        print(f"Using HuggingFace dataset: {config.hf_dataset.name}")
        return setup_hf_data_and_update_config(
            config,
            dataset_name=config.hf_dataset.name,
            smiles_column=config.hf_dataset.get('smiles_column', 'smiles'),
        )

    # Imported only when this branch is taken.
    from dataloading_for_dynamic_batching import setup_data_and_update_config
    print("Using local dataset")
    return setup_data_and_update_config(config)


def _build_checkpoint_callbacks(config):
    """Create the best-loss and the periodic ``ModelCheckpoint`` callbacks.

    Returns:
        A two-element list: the callback monitoring ``train/total_loss``
        followed by the periodic (keep-everything) callback.

    Raises:
        ValueError: if neither ``training.save_every_n_steps`` nor
            ``training.save_every_n_epochs`` is set.
    """
    checkpoint_dir = config.training.checkpoint_dir

    best_checkpoint = ModelCheckpoint(
        monitor="train/total_loss",
        mode="min",
        save_top_k=config.training.save_top_k,
        save_last=True,
        filename="epoch-{epoch:02d}-train_loss-{train/total_loss:.4f}",
        dirpath=checkpoint_dir,
        auto_insert_metric_name=False,
    )

    # Step-based saving takes priority over epoch-based saving.
    save_every_n_steps = config.training.get('save_every_n_steps', None)
    save_every_n_epochs = config.training.get('save_every_n_epochs', None)
    if save_every_n_steps is not None:
        periodic_checkpoint = ModelCheckpoint(
            save_top_k=-1,
            filename="step-{step:08d}",
            dirpath=checkpoint_dir,
            every_n_train_steps=save_every_n_steps,
            auto_insert_metric_name=False,
        )
    elif save_every_n_epochs is not None:
        periodic_checkpoint = ModelCheckpoint(
            save_top_k=-1,
            filename="epoch-{epoch:02d}",
            dirpath=checkpoint_dir,
            every_n_epochs=save_every_n_epochs,
            auto_insert_metric_name=False,
        )
    else:
        raise ValueError(
            "Either save_every_n_steps or save_every_n_epochs must be specified in the config"
        )

    return [best_checkpoint, periodic_checkpoint]


def train(config):
    """Run one training job described by ``config``.

    Steps, in order: seed the RNGs, start W&B logging on local rank 0,
    append a timestamp to ``config.training.checkpoint_dir`` and create it,
    build the data module and the ``AnyOrderInsertionFlowModule``, configure
    a DDP GPU ``pl.Trainer`` with two checkpoint callbacks, and fit
    (optionally resuming from ``config.training.resume_path``).

    Args:
        config: OmegaConf config with ``wandb``, ``model`` and ``training``
            sections and an optional ``hf_dataset`` section. It is mutated
            in place (``training.checkpoint_dir``, ``training.warmup_steps``).

    Raises:
        ValueError: if neither ``training.max_steps`` nor
            ``training.num_epochs`` is set; if ``training.warmup_steps`` is
            unset while training by epochs (no ``max_steps`` to derive it
            from); or if no periodic checkpoint interval is configured.
    """
    pl.seed_everything(42)
    torch.manual_seed(42)

    # Only the process with LOCAL_RANK 0 (or no LOCAL_RANK at all) logs to W&B.
    wandb_logger = None
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        wandb_logger = _init_wandb_logger(config)

    # Give every run its own timestamped checkpoint directory. Struct mode is
    # lifted around the write and switched on again afterwards.
    # NOTE(review): each process computes its own timestamp, so processes that
    # reach this line in different seconds create different directories - confirm
    # this is acceptable for multi-process runs.
    OmegaConf.set_struct(config, False)
    time_string = datetime.now().strftime("%Y%m%d-%H%M%S")
    config.training.checkpoint_dir = os.path.join(
        config.training.checkpoint_dir, time_string
    )
    OmegaConf.set_struct(config, True)
    os.makedirs(config.training.checkpoint_dir, exist_ok=True)

    data_module = _build_data_module(config)
    module = AnyOrderInsertionFlowModule(config)

    # Translate the configured model dtype into a Trainer precision setting;
    # unknown or missing values fall back to bf16 mixed precision.
    dtype_str = config.model.get('torch_dtype', 'bfloat16')
    precision_map = {
        'float32': '32-true',
        'float16': '16-mixed',
        'bfloat16': 'bf16-mixed',
    }
    precision = precision_map.get(dtype_str, 'bf16-mixed')

    # Accumulate gradients so that per_gpu_batch_size * nodes * devices *
    # accumulation steps reaches the configured global batch size
    # (floor division).
    samples_per_step = (
        config.training.per_gpu_batch_size
        * config.training.nodes
        * config.training.devices
    )
    trainer_kwargs = dict(
        num_nodes=config.training.nodes,
        accelerator="gpu",
        devices=config.training.devices,
        strategy="ddp",
        precision=precision,
        accumulate_grad_batches=config.training.batch_size // samples_per_step,
        log_every_n_steps=10,
        enable_checkpointing=True,
        default_root_dir=config.training.checkpoint_dir,
        gradient_clip_val=1.0,
    )

    # Training length: an explicit step budget wins over an epoch budget.
    max_steps = config.training.max_steps
    if max_steps is not None:
        trainer_kwargs["max_steps"] = max_steps
    elif config.training.num_epochs is not None:
        trainer_kwargs["max_epochs"] = config.training.num_epochs
    else:
        raise ValueError(
            "Either max_steps or num_epochs must be specified in the config"
        )

    # Default warmup is 1% of the step budget. When training by epochs there
    # is no step budget here, so fail with a clear message instead of the
    # TypeError that `None * 0.01` would produce.
    if config.training.warmup_steps is None:
        if max_steps is None:
            raise ValueError(
                "training.warmup_steps must be set explicitly when training "
                "by num_epochs: its default is 1% of training.max_steps, "
                "which is not set"
            )
        config.training.warmup_steps = int(max_steps * 0.01)

    trainer_kwargs["callbacks"] = _build_checkpoint_callbacks(config)

    if wandb_logger is not None:
        trainer_kwargs["logger"] = wandb_logger

    trainer = pl.Trainer(**trainer_kwargs)

    # Resume from a checkpoint only when the config names one.
    ckpt_path = None
    if "resume_path" in config.training:
        ckpt_path = config.training.resume_path

    trainer.fit(module, datamodule=data_module, ckpt_path=ckpt_path)

    # Close the W&B run from the same process that opened it.
    if wandb_logger is not None:
        wandb.finish()
|
|
|
|
if __name__ == '__main__':
    # A thin argparse layer in front of Hydra: it understands --config_name
    # and --task, and forwards every other argument to Hydra untouched.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_name', type=str, default='config',
                        help='Name of the config file to use')
    parser.add_argument('--task', type=str, default=None,
                        help='Task name (uses config_{task}.yaml)')

    # Arguments argparse does not recognise (Hydra flags and overrides) end
    # up in `unknown`.
    args, unknown = parser.parse_known_args()

    # --task takes precedence over --config_name.
    config_name = f'config_{args.task}' if args.task else args.config_name

    # Hydra's own selector may already be present, spelled as
    # "--config-name NAME", "--config-name=NAME", "-cn NAME" or "-cn=NAME".
    # In that case leave the arguments alone: injecting a second selector
    # would be redundant, and announcing `config_name` would be misleading.
    hydra_selector_given = any(
        arg in ('--config-name', '-cn')
        or arg.startswith(('--config-name=', '-cn='))
        for arg in unknown
    )
    if hydra_selector_given:
        print("Using config selected by Hydra's --config-name/-cn argument")
    else:
        print(f"Using config: {config_name}.yaml")
        unknown.insert(0, f'--config-name={config_name}')

    # Hydra reads sys.argv itself, so hand it only the forwarded arguments.
    sys.argv = [sys.argv[0]] + unknown

    @hydra.main(version_base=None,
                config_path=CONFIG_DIR,
                config_name='config')
    def main(config):
        """Main entry point for training"""
        train(config)

    main()