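"""Hydra-driven training entry point for a Lightning pipeline.

Typical invocation (assuming this file sits one level below the project root,
e.g. at src/train.py, as implied by config_path="../configs"):

    python src/train.py seed=42

Any other field of the composed config can be overridden on the command line
using standard Hydra syntax.
"""
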
import rootutils
import hydra
from omegaconf import DictConfig
import lightning as L
import torch
from pathlib import Path
from lightning.pytorch.loggers import Logger, WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from typing import List, Dict, Any

# Setup root: pythonpath=True puts the project root on sys.path so `src.*` imports resolve
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from src.utils import instantiate_callbacks, instantiate_loggers, RankedLogger, extras  # noqa: E402
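
# rank_zero_only=True restricts logging to the global rank-zero process, so
# multi-GPU runs don't print duplicate messages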
log = RankedLogger(__name__, rank_zero_only=True)


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Dict[str, Any]:
    # Set seed
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    # Applies optional utilities
    extras(cfg)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: L.LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[L.Callback] = instantiate_callbacks(cfg.get("callbacks"))
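
    # The two checks below are purely informational: they warn when the composed
    # config would silently fall back to Lightning's default checkpointing behavior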
    callbacks_cfg = cfg.get("callbacks")
    if (
        isinstance(callbacks_cfg, DictConfig)
        and "model_checkpoint" in callbacks_cfg
        and callbacks_cfg.model_checkpoint is None
    ):
        log.warning(
            "`callbacks.model_checkpoint` is null in the composed config. "
            "Lightning will use its default ModelCheckpoint callback, which may not "
            "save `last.ckpt` and can change filename conventions. Remove the null "
            "override or set explicit checkpoint fields in the experiment config."
        )

    if cfg.get("train") and not any(
        isinstance(callback, ModelCheckpoint) for callback in callbacks
    ):
        log.warning(
            "No explicit ModelCheckpoint callback was instantiated from config; "
            "Lightning default checkpointing behavior will be used."
        )
log.info("Instantiating loggers...")
logger: List[L.Logger] = instantiate_loggers(cfg.get("logger"))
# Set float32 matmul precision for Tensor Cores
torch.set_float32_matmul_precision("medium")
    # Log config tree and .hydra folder to wandb
    for lg in logger:
        if isinstance(lg, WandbLogger):
            # Check if config_tree.log exists in the output dir
            config_tree_path = Path(cfg.paths.output_dir, "config_tree.log")
            if config_tree_path.exists():
                log.info("Logging config tree to WandB...")
                # policy="now" uploads the file immediately instead of at the end of the run
                lg.experiment.save(
                    str(config_tree_path), policy="now", base_path=cfg.paths.output_dir
                )
            # Upload .hydra folder contents (composed config, overrides, hydra settings)
            hydra_dir = Path(cfg.paths.output_dir, ".hydra")
            if hydra_dir.exists() and hydra_dir.is_dir():
                log.info("Logging .hydra folder to WandB...")
                for hydra_file in hydra_dir.iterdir():
                    if hydra_file.is_file():
                        lg.experiment.save(
                            str(hydra_file),
                            policy="now",
                            base_path=cfg.paths.output_dir,
                        )
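
    # Passing callbacks/logger as kwargs to instantiate() makes them take precedence
    # over any values with the same keys inside cfg.trainer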
log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: L.Trainer = hydra.utils.instantiate(
cfg.trainer,
callbacks=callbacks,
logger=logger,
)
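
    # Collect references to everything that was instantiated; the dict is returned
    # at the end of `main` so callers that import it (e.g. tests) can inspect them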
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }
if cfg.get("train"):
log.info("Starting training!")
trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
if cfg.get("test"):
log.info("Starting testing!")
ckpt_path = trainer.checkpoint_callback.best_model_path
if ckpt_path == "":
log.warning("Best ckpt not found! Using current weights for testing...")
ckpt_path = None
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
return object_dict


if __name__ == "__main__":
    main()