File size: 4,217 Bytes
eca55dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import rootutils
import hydra
from omegaconf import DictConfig
import lightning as L
import torch
from pathlib import Path
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from typing import List, Dict, Any

# Setup root
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from src.utils import instantiate_callbacks, instantiate_loggers, RankedLogger, extras  # noqa: E402

log = RankedLogger(__name__, rank_zero_only=True)


def _warn_on_missing_checkpoint_callback(
    cfg: DictConfig, callbacks: List[L.Callback]
) -> None:
    """Warn when checkpointing will fall back to Lightning's built-in defaults.

    Two cases are detected: an explicit ``null`` override of
    ``callbacks.model_checkpoint`` in the composed config, and a training run
    where no ``ModelCheckpoint`` instance was actually instantiated.
    """
    callbacks_cfg = cfg.get("callbacks")
    # A null `model_checkpoint` override silently disables the project's
    # configured checkpointing; surface that loudly instead.
    if (
        isinstance(callbacks_cfg, DictConfig)
        and "model_checkpoint" in callbacks_cfg
        and callbacks_cfg.model_checkpoint is None
    ):
        log.warning(
            "`callbacks.model_checkpoint` is null in the composed config. "
            "Lightning will use its default ModelCheckpoint callback, which may not "
            "save `last.ckpt` and can change filename conventions. Remove the null "
            "override or set explicit checkpoint fields in the experiment config."
        )

    if cfg.get("train") and not any(
        isinstance(callback, ModelCheckpoint) for callback in callbacks
    ):
        log.warning(
            "No explicit ModelCheckpoint callback was instantiated from config; "
            "Lightning default checkpointing behavior will be used."
        )


def _upload_run_configs_to_wandb(cfg: DictConfig, loggers: List[Any]) -> None:
    """Upload the rendered config tree and the ``.hydra`` folder to WandB.

    No-op when no ``WandbLogger`` is configured or the files do not exist.
    Files are saved with ``policy="now"`` so they appear immediately in the run.
    """
    wandb_loggers = [lg for lg in loggers if isinstance(lg, WandbLogger)]
    if not wandb_loggers:
        return

    output_dir = cfg.paths.output_dir
    config_tree_path = Path(output_dir, "config_tree.log")
    hydra_dir = Path(output_dir, ".hydra")

    for lg in wandb_loggers:
        # check if config_tree.log exists
        if config_tree_path.exists():
            log.info("Logging config tree to WandB...")
            lg.experiment.save(
                str(config_tree_path), policy="now", base_path=output_dir
            )

        # Upload .hydra folder contents (overrides.yaml, config.yaml, ...)
        if hydra_dir.is_dir():
            log.info("Logging .hydra folder to WandB...")
            for hydra_file in hydra_dir.iterdir():
                if hydra_file.is_file():
                    lg.experiment.save(
                        str(hydra_file),
                        policy="now",
                        base_path=output_dir,
                    )


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Dict[str, Any]:
    """Train and optionally test a model from a composed Hydra config.

    Instantiates the datamodule, model, callbacks, loggers and trainer from
    ``cfg``, runs ``trainer.fit`` when ``cfg.train`` is truthy and
    ``trainer.test`` when ``cfg.test`` is truthy.

    Args:
        cfg: The composed Hydra configuration for this run.

    Returns:
        A dict of all instantiated objects (``cfg``, ``datamodule``, ``model``,
        ``callbacks``, ``logger``, ``trainer``) for downstream inspection.
    """
    # Set seed for reproducibility (workers=True also seeds dataloader workers)
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    # Applies optional utilities
    extras(cfg)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: L.LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[L.Callback] = instantiate_callbacks(cfg.get("callbacks"))
    _warn_on_missing_checkpoint_callback(cfg, callbacks)

    log.info("Instantiating loggers...")
    logger: List[L.Logger] = instantiate_loggers(cfg.get("logger"))

    # Set float32 matmul precision for Tensor Cores
    torch.set_float32_matmul_precision("medium")

    # Log config tree and .hydra folder to wandb
    _upload_run_configs_to_wandb(cfg, logger)

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: L.Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=logger,
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    if cfg.get("test"):
        log.info("Starting testing!")
        # BUGFIX: `trainer.checkpoint_callback` is None when no ModelCheckpoint
        # callback was registered (the very case warned about above), so the
        # original unguarded `.best_model_path` access raised AttributeError.
        checkpoint_cb = trainer.checkpoint_callback
        ckpt_path = checkpoint_cb.best_model_path if checkpoint_cb is not None else ""
        # `not ckpt_path` covers both the empty string and None.
        if not ckpt_path:
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    return object_dict


if __name__ == "__main__":
    # Hydra's @hydra.main decorator parses CLI overrides, composes the config,
    # and invokes the wrapped function with the resulting DictConfig.
    main()