# A2D2 / a2d2_mol / train.py
# Commit 8019be0 by Sophia: "initial commit" (file size 7.2 kB)
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import os
import sys
import argparse
import hydra
from omegaconf import OmegaConf
from datetime import datetime
# Directory containing this file and the config_*.yaml files; passed to
# hydra.main as config_path below.
CONFIG_DIR = os.path.dirname(os.path.abspath(__file__))
# Put the parent of CONFIG_DIR (the repo root, A2D2/) at the front of sys.path
# so top-level packages like lightning_modules resolve. This must run before
# the project import below.
sys.path.insert(0, os.path.dirname(CONFIG_DIR))
import wandb
from lightning_modules import AnyOrderInsertionFlowModule
# Print tensors with up to 10,000 elements in full before torch summarizes them.
torch.set_printoptions(threshold=10_000)
# Let float32 matmuls use faster reduced-precision internal kernels (e.g. TF32)
# where the hardware supports them.
torch.set_float32_matmul_precision("high")
# Disable DDP optimizer due to incompatibility with flex_attention higher-order ops
torch._dynamo.config.optimize_ddp = False
# Environment variable used to hand the run timestamp from the launching
# process to DDP worker processes, so all ranks agree on checkpoint_dir.
_RUN_TIMESTAMP_ENV = "A2D2_RUN_TIMESTAMP"


def _is_global_rank_zero():
    """Return True if this process should own the W&B run (global rank 0).

    Decided from environment variables because the Trainer has not set up the
    process group yet. LOCAL_RANK alone is not sufficient: on a multi-node job
    every node has a local rank 0, which would start one W&B run per node.
    """
    if int(os.environ.get("LOCAL_RANK", 0)) != 0:
        return False
    # Launchers that know the global rank (torchrun sets RANK, srun sets
    # SLURM_PROCID) take precedence; otherwise fall back to the node rank.
    for key in ("RANK", "SLURM_PROCID"):
        if key in os.environ:
            return int(os.environ[key]) == 0
    return int(os.environ.get("NODE_RANK", 0)) == 0


def _timestamped_checkpoint_dir(base_dir):
    """Return ``base_dir/<YYYYmmdd-HHMMSS>`` for this run.

    The local-rank-0 process picks a fresh timestamp and exports it through
    ``_RUN_TIMESTAMP_ENV``. Worker processes that inherit that environment
    (assumed: spawned from it, as a subprocess launcher does) reuse the same
    value instead of computing their own, which would differ because workers
    start later. Processes without the variable fall back to a fresh timestamp.
    """
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        os.environ[_RUN_TIMESTAMP_ENV] = stamp
    else:
        stamp = (os.environ.get(_RUN_TIMESTAMP_ENV)
                 or datetime.now().strftime("%Y%m%d-%H%M%S"))
    return os.path.join(base_dir, stamp)


def _accumulate_grad_batches(training_cfg):
    """Return how many micro-batches to accumulate to reach ``batch_size``.

    One optimizer step sees ``per_gpu_batch_size * nodes * devices`` samples
    per micro-batch, times the returned accumulation count.

    Raises:
        ValueError: if ``batch_size`` is smaller than one micro-batch across
            all devices, i.e. the accumulation count would be below 1.
    """
    per_step = (training_cfg.per_gpu_batch_size
                * training_cfg.nodes
                * training_cfg.devices)
    accumulate, remainder = divmod(training_cfg.batch_size, per_step)
    if accumulate < 1:
        raise ValueError(
            f"training.batch_size ({training_cfg.batch_size}) must be at least "
            f"per_gpu_batch_size * nodes * devices ({per_step})"
        )
    if remainder:
        print(
            f"Warning: training.batch_size ({training_cfg.batch_size}) is not a "
            f"multiple of per_gpu_batch_size * nodes * devices ({per_step}); "
            f"the effective batch size is {accumulate * per_step}."
        )
    return accumulate


def train(config):
    """Train ``AnyOrderInsertionFlowModule`` with DDP from a composed config.

    Steps, in order:
      1. Seed the RNGs.
      2. Suffix ``training.checkpoint_dir`` with a run timestamp and create it.
      3. Build the data module: HuggingFace/SAFE when the config has an
         ``hf_dataset`` section, local data otherwise.
      4. Resolve the step schedule and validate the remaining config.
      5. Start W&B on global rank 0 only.
      6. Build the module and the Trainer.
      7. Fit, then close the W&B run.

    Args:
        config: OmegaConf config composed by Hydra. It is modified in place:
            ``training.checkpoint_dir`` gains the timestamp suffix, and
            ``training.warmup_steps`` defaults to 1% of ``max_steps`` when it
            is None. The data setup helpers presumably update it as well.

    Raises:
        ValueError: if neither ``training.max_steps`` nor
            ``training.num_epochs`` is set; if ``warmup_steps`` is None while
            training by epochs; if the batch-size settings give fewer than one
            accumulation step; or if neither ``save_every_n_steps`` nor
            ``save_every_n_epochs`` is set.
    """
    from omegaconf import open_dict

    is_rank_zero = _is_global_rank_zero()
    wandb_logger = None

    # seed_everything also seeds torch, so a separate torch.manual_seed is redundant.
    pl.seed_everything(42)

    # Give each run its own checkpoint subdirectory. open_dict restores the
    # config's previous struct flag on exit, even if the assignment raises.
    with open_dict(config):
        config.training.checkpoint_dir = _timestamped_checkpoint_dir(
            config.training.checkpoint_dir
        )
    os.makedirs(config.training.checkpoint_dir, exist_ok=True)

    # Set up the data module. Judging by their names, both helpers may write
    # derived values back into config, so everything that reads config below
    # runs after this.
    if 'hf_dataset' in config:
        # Imported lazily: the HF/SAFE path is only used by the molecule configs,
        # which keep mol_dataset.py (and its `safe` dependency) in a2d2_mol/.
        from mol_dataset import setup_hf_data_and_update_config
        print(f"Using HuggingFace dataset: {config.hf_dataset.name}")
        data_module = setup_hf_data_and_update_config(
            config,
            dataset_name=config.hf_dataset.name,
            smiles_column=config.hf_dataset.get('smiles_column', 'smiles'),
        )
    else:
        # Imported lazily: the local (arrow) path is used by the peptide config,
        # which keeps dataloading_for_dynamic_batching.py in a2d2_pep/.
        # NOTE(review): this file only adds the repo root to sys.path, so this
        # import presumably resolves only when a2d2_pep/ is importable - confirm.
        from dataloading_for_dynamic_batching import setup_data_and_update_config
        print("Using local dataset")
        data_module = setup_data_and_update_config(config)

    # Training length: exactly one of max_steps / max_epochs reaches the Trainer,
    # and max_steps takes precedence.
    max_steps = config.training.max_steps
    num_epochs = config.training.num_epochs
    if max_steps is None and num_epochs is None:
        raise ValueError(
            "Either max_steps or num_epochs must be specified in the config"
        )
    # Resolve warmup before the module is built so it sees the final value.
    # The 1% default needs a step budget, which is unknown when training by epochs.
    if config.training.warmup_steps is None:
        if max_steps is None:
            raise ValueError(
                "warmup_steps must be set explicitly when training with "
                "num_epochs instead of max_steps"
            )
        config.training.warmup_steps = int(max_steps * 0.01)

    accumulate_grad_batches = _accumulate_grad_batches(config.training)

    # Keep the save_top_k checkpoints with the lowest *training* loss, plus
    # last.ckpt. It monitors train/total_loss, so it works even when
    # validation does not run.
    checkpoint_callback = ModelCheckpoint(
        monitor="train/total_loss",
        mode="min",
        save_top_k=config.training.save_top_k,
        save_last=True,
        filename="epoch-{epoch:02d}-train_loss-{train/total_loss:.4f}",
        dirpath=config.training.checkpoint_dir,
        # The metric name contains '/', so don't let Lightning insert it into
        # the filename (the template already labels each value).
        auto_insert_metric_name=False,
    )
    # Separate callback for periodic saves that keeps every checkpoint. Use
    # step-based saves for streaming datasets (save_every_n_steps) and
    # epoch-based saves otherwise (save_every_n_epochs), whichever the config
    # provides; steps win if both are set.
    save_every_n_steps = config.training.get('save_every_n_steps', None)
    save_every_n_epochs = config.training.get('save_every_n_epochs', None)
    if save_every_n_steps is not None:
        periodic_checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,  # Save all periodic checkpoints
            filename="step-{step:08d}",
            dirpath=config.training.checkpoint_dir,
            every_n_train_steps=save_every_n_steps,
            auto_insert_metric_name=False,
        )
    elif save_every_n_epochs is not None:
        periodic_checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,  # Save all periodic checkpoints
            filename="epoch-{epoch:02d}",
            dirpath=config.training.checkpoint_dir,
            every_n_epochs=save_every_n_epochs,
            auto_insert_metric_name=False,
        )
    else:
        raise ValueError(
            "Either save_every_n_steps or save_every_n_epochs must be specified in the config"
        )

    # Only global rank 0 starts W&B. It runs after the config is fully
    # resolved, so the logged config matches what training actually uses.
    if is_rank_zero:
        wandb.init(
            project=config.wandb.project,
            name=config.wandb.name,
            config=OmegaConf.to_container(config, resolve=True),  # Convert to dict
            dir=config.wandb.path,
        )
        wandb_logger = WandbLogger(
            project=wandb.run.project,
            name=wandb.run.name,
            log_model=False,  # Disable checkpoint uploading to save disk space
        )

    module = AnyOrderInsertionFlowModule(config)

    # Map model.torch_dtype to a Lightning precision string.
    dtype_str = config.model.get('torch_dtype', 'bfloat16')
    precision_map = {
        'float32': '32-true',
        'float16': '16-mixed',
        'bfloat16': 'bf16-mixed',
    }
    if dtype_str not in precision_map:
        print(f"Warning: unknown model.torch_dtype {dtype_str!r}; "
              f"using 'bf16-mixed' precision")
    precision = precision_map.get(dtype_str, 'bf16-mixed')

    trainer_kwargs = dict(
        num_nodes=config.training.nodes,
        accelerator="gpu",
        devices=config.training.devices,
        strategy="ddp",
        precision=precision,
        accumulate_grad_batches=accumulate_grad_batches,
        log_every_n_steps=10,
        enable_checkpointing=True,
        default_root_dir=config.training.checkpoint_dir,
        gradient_clip_val=1.0,
        callbacks=[checkpoint_callback, periodic_checkpoint_callback],
    )
    if max_steps is not None:
        trainer_kwargs["max_steps"] = max_steps
    else:
        trainer_kwargs["max_epochs"] = num_epochs
    # Ranks without a W&B logger keep the Trainer's default logger.
    if wandb_logger is not None:
        trainer_kwargs["logger"] = wandb_logger
    trainer = pl.Trainer(**trainer_kwargs)

    # Resume from training.resume_path when the key is present (None = fresh start).
    ckpt_path = config.training.resume_path if "resume_path" in config.training else None
    trainer.fit(module, datamodule=data_module, ckpt_path=ckpt_path)

    # Only the process that started the W&B run finishes it.
    if is_rank_zero:
        wandb.finish()
def config_name_from_overrides(argv):
    """Return the config name set by Hydra's ``--config-name``/``-cn`` in *argv*.

    Accepts both ``--config-name NAME`` and ``--config-name=NAME`` (and the
    same two forms for ``-cn``). If the flag appears more than once, the last
    occurrence wins, matching argparse's default ``store`` behaviour. Returns
    None when no (complete) override is present.

    >>> config_name_from_overrides(['-cn', 'config_mol', 'training.lr=1e-4'])
    'config_mol'
    """
    found = None
    for index, arg in enumerate(argv):
        for flag in ('--config-name', '-cn'):
            if arg == flag and index + 1 < len(argv):
                found = argv[index + 1]
            elif arg.startswith(flag + '='):
                found = arg[len(flag) + 1:]
    return found


if __name__ == '__main__':
    # Parse this script's own flags; everything else is passed through to Hydra.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_name', type=str, default='config',
                        help='Name of the config file to use')
    parser.add_argument('--task', type=str, default=None,
                        help='Task name (uses config_{task}.yaml)')
    args, unknown = parser.parse_known_args()

    # An explicit Hydra --config-name/-cn override wins. This also covers DDP
    # worker processes, which (per the original design) re-run this script with
    # the rewritten sys.argv below and therefore already carry the override.
    # Without this check, a second --config-name would be inserted and the
    # wrong name printed.
    config_name = config_name_from_overrides(unknown)
    if config_name is None:
        config_name = f'config_{args.task}' if args.task else args.config_name
        # Put the config name into the Hydra overrides so it persists in sys.argv.
        unknown.insert(0, f'--config-name={config_name}')
    print(f"Using config: {config_name}.yaml")

    # Hydra parses sys.argv itself, so hand it only the pass-through arguments.
    sys.argv = [sys.argv[0]] + unknown

    # Define main function with default config (will be overridden by command line)
    @hydra.main(version_base=None,
                config_path=CONFIG_DIR,
                config_name='config')
    def main(config):
        """Main entry point for training"""
        train(config)

    main()