import rootutils
import hydra
from omegaconf import DictConfig
import lightning as L
import torch
from pathlib import Path
from lightning.pytorch.loggers import Logger, WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from typing import List, Dict, Any

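# Locate the project root via the ".project-root" marker and add it to the
# Python path so that `src` imports resolve regardless of the working directory.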
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

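# These imports rely on the project root being on the Python path (set up above).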
from src.utils import instantiate_callbacks, instantiate_loggers, RankedLogger, extras

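# Rank-aware logger: with rank_zero_only=True, messages are emitted only once
# from the main process under multi-process (e.g. DDP) training.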
log = RankedLogger(__name__, rank_zero_only=True)


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Dict[str, Any]:
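    """Train and optionally test a model according to the composed Hydra config.

    Returns a dict of the instantiated objects (config, datamodule, model,
    callbacks, loggers, trainer).
    """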
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

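    # Apply optional utilities (e.g. printing/saving the config tree) implemented
    # in src.utils.extras.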
    extras(cfg)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: L.LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[L.Callback] = instantiate_callbacks(cfg.get("callbacks"))

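    # Warn if the model_checkpoint callback was explicitly nulled out in the
    # composed config, since that silently changes checkpointing behavior.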
    callbacks_cfg = cfg.get("callbacks")
    if (
        isinstance(callbacks_cfg, DictConfig)
        and "model_checkpoint" in callbacks_cfg
        and callbacks_cfg.model_checkpoint is None
    ):
        log.warning(
            "`callbacks.model_checkpoint` is null in the composed config. "
            "Lightning will use its default ModelCheckpoint callback, which may not "
            "save `last.ckpt` and can change filename conventions. Remove the null "
            "override or set explicit checkpoint fields in the experiment config."
        )

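    # Also warn when training will run without any explicit ModelCheckpoint callback.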
    if cfg.get("train") and not any(
        isinstance(callback, ModelCheckpoint) for callback in callbacks
    ):
        log.warning(
            "No explicit ModelCheckpoint callback was instantiated from config; "
            "Lightning default checkpointing behavior will be used."
        )

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

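    # Trade some float32 matmul precision for speed on supported GPUs
    # (PyTorch's "medium" setting).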
    torch.set_float32_matmul_precision("medium")

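    # If a WandbLogger is active, upload the composed config artifacts
    # (the config tree and the .hydra folder) to the run.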
    for lg in logger:
        if isinstance(lg, WandbLogger):
            config_tree_path = Path(cfg.paths.output_dir, "config_tree.log")
            if config_tree_path.exists():
                log.info("Logging config tree to WandB...")
                lg.experiment.save(
                    str(config_tree_path), policy="now", base_path=cfg.paths.output_dir
                )

            hydra_dir = Path(cfg.paths.output_dir, ".hydra")
            if hydra_dir.exists() and hydra_dir.is_dir():
                log.info("Logging .hydra folder to WandB...")
                for hydra_file in hydra_dir.iterdir():
                    if hydra_file.is_file():
                        lg.experiment.save(
                            str(hydra_file),
                            policy="now",
                            base_path=cfg.paths.output_dir,
                        )

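    # The trainer receives the already-instantiated callbacks and loggers, which
    # are configured separately from cfg.trainer.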
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: L.Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=logger,
    )

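    # Collect the instantiated objects into the dict this function returns.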
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

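    # Fit the model, optionally resuming from cfg.ckpt_path if it is set.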
    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

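    # Evaluate on the test set, preferring the best checkpoint if one was saved;
    # fall back to the current in-memory weights otherwise.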
    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt = trainer.checkpoint_callback
        ckpt_path = ckpt.best_model_path if ckpt is not None else ""
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    return object_dict


if __name__ == "__main__":
    main()