| import argparse |
| import contextlib |
| import importlib |
| import logging |
| import os |
| import sys |
| import time |
| import traceback |
| import pytorch_lightning as pl |
| import torch |
| from pytorch_lightning import Trainer |
| from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint |
| from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger |
| from pytorch_lightning.utilities.rank_zero import rank_zero_only |
| import craftsman |
| from craftsman.systems.base import BaseSystem |
| from craftsman.utils.callbacks import ( |
| CodeSnapshotCallback, |
| ConfigSnapshotCallback, |
| CustomProgressBar, |
| ProgressCallback, |
| ) |
| from craftsman.utils.config import ExperimentConfig, load_config |
| from craftsman.utils.misc import get_rank |
| from craftsman.utils.typing import Optional |
class ColoredFilter(logging.Filter):
    """Logging filter that decorates records of known levels with ANSI colors.

    For a record whose ``levelname`` is a key of ``COLORS``, the level name is
    rewritten to ``<color>[LEVEL]`` and ``RESET`` is appended to the message.
    With a formatter that prints the level name before the message, the
    whole line is therefore colored. Records with any other level name are
    left untouched. The filter never drops a record.
    """

    # ANSI escape sequences. RESET restores the terminal's default style.
    RESET = "\033[0m"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"

    # Standard level name -> color prefix.
    COLORS = {
        "WARNING": YELLOW,
        "INFO": GREEN,
        "DEBUG": BLUE,
        "CRITICAL": MAGENTA,
        "ERROR": RED,
    }

    def filter(self, record):
        """Colorize ``record`` in place and always let it through.

        Args:
            record: The ``logging.LogRecord`` being processed.

        Returns:
            Always ``True``.
        """
        color_start = self.COLORS.get(record.levelname)
        if color_start is not None:
            # After this rewrite the level name is no longer a COLORS key, so
            # running the filter again on the same record is a no-op.
            record.levelname = f"{color_start}[{record.levelname}]"
            record.msg = f"{record.msg}{self.RESET}"
        return True
|
|
|
|
def load_custom_module(module_path):
    """Import one custom module from a ``.py`` file or a package directory.

    A regular file is executed directly and registered in ``sys.modules``
    under its path without the extension. Anything else is treated as a
    package directory: its ``__init__.py`` is executed and the module is
    registered under the directory's base name.

    Args:
        module_path: Path to a Python source file or to a package directory.

    Returns:
        ``True`` if the module was imported, ``False`` if anything raised.
        On failure the traceback and a short message are printed and no
        partially initialised module is left in ``sys.modules``.
    """
    # Imported here because a bare ``import importlib`` does not guarantee
    # that the ``importlib.util`` submodule has been loaded.
    import importlib.util

    if os.path.isfile(module_path):
        module_name = os.path.splitext(module_path)[0]
        source_path = module_path
    else:
        module_name = os.path.basename(module_path)
        source_path = os.path.join(module_path, "__init__.py")

    module = None
    try:
        module_spec = importlib.util.spec_from_file_location(
            module_name, source_path
        )
        module = importlib.util.module_from_spec(module_spec)
        # Register before executing so the module can be found by name while
        # its own body is running.
        sys.modules[module_name] = module
        module_spec.loader.exec_module(module)
        return True
    except Exception as e:
        # Drop the half-initialised module so later lookups do not pick up a
        # broken object.
        if module is not None and sys.modules.get(module_name) is module:
            del sys.modules[module_name]
        print(traceback.format_exc())
        print(f"Cannot import {module_path} module for custom nodes:", e)
        return False
|
|
|
|
def load_custom_modules():
    """Import every custom module found in the ``custom`` directory.

    Each entry of the directory is passed to ``load_custom_module`` unless it
    is ``__pycache__``, a regular file without the ``.py`` extension, or has
    a path ending in ``.disabled``. A search directory that does not exist
    is skipped. If at least one import was attempted, a summary of import
    times is printed, sorted by duration, with failed imports marked.
    """
    node_paths = ["custom"]
    node_import_times = []

    for custom_node_path in node_paths:
        # Check the directory itself. The previous guard tested the literal
        # string "node_paths", so custom modules were effectively never
        # loaded.
        if not os.path.isdir(custom_node_path):
            continue

        for possible_module in os.listdir(custom_node_path):
            if possible_module == "__pycache__":
                continue
            module_path = os.path.join(custom_node_path, possible_module)
            if (
                os.path.isfile(module_path)
                and os.path.splitext(module_path)[1] != ".py"
            ):
                continue
            if module_path.endswith(".disabled"):
                continue

            time_before = time.perf_counter()
            success = load_custom_module(module_path)
            node_import_times.append(
                (time.perf_counter() - time_before, module_path, success)
            )

    if node_import_times:
        print("\nImport times for custom modules:")
        for elapsed, module_path, success in sorted(node_import_times):
            import_message = "" if success else " (IMPORT FAILED)"
            print("{:6.1f} seconds{}:".format(elapsed, import_message), module_path)
        print()
|
|
|
|
def main(args, extras) -> None:
    """Set up the environment and run the requested Lightning stage.

    Args:
        args: Parsed command-line namespace. This function reads ``gpu``,
            ``typecheck``, ``verbose``, ``gradio``, ``config``, ``train``,
            ``validate``, ``test`` and ``export``.
        extras: Command-line tokens the parser did not recognise, forwarded
            to ``load_config`` as ``cli_args``.
    """
    # GPU selection goes through CUDA_VISIBLE_DEVICES. A value already
    # present in the environment wins over --gpu.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    env_gpus = env_gpus_str.split(",") if env_gpus_str else []
    torch.set_float32_matmul_precision("high")

    # The Trainer is always given devices=-1; narrowing the set of GPUs is
    # left entirely to CUDA_VISIBLE_DEVICES.
    devices = -1
    if len(env_gpus) > 0:
        n_gpus = len(env_gpus)
    else:
        selected_gpus = args.gpu.split(",")
        n_gpus = len(selected_gpus)
        print(f"Using {n_gpus} GPUs: {selected_gpus}")
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    if args.typecheck:
        from jaxtyping import install_import_hook

        install_import_hook("craftsman", "typeguard.typechecked")

    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    for handler in logger.handlers:
        # Handlers without a ``stream`` attribute are left alone instead of
        # raising AttributeError.
        if getattr(handler, "stream", None) is sys.stderr:
            if not args.gradio:
                handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
                handler.addFilter(ColoredFilter())
            else:
                handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))

    load_custom_modules()

    # Parse the YAML/CLI configuration.
    cfg: ExperimentConfig = load_config(args.config, cli_args=extras, n_gpus=n_gpus)

    # The seed is offset by get_rank(), presumably so that each process gets
    # a different seed (verify against get_rank).
    pl.seed_everything(cfg.seed + get_rank(), workers=True)

    dm = craftsman.find(cfg.data_type)(cfg.data)
    system: BaseSystem = craftsman.find(cfg.system_type)(
        cfg.system, resumed=cfg.resume is not None
    )
    system.set_save_dir(os.path.join(cfg.trial_dir, "save"))

    if args.gradio:
        # In gradio mode the Lightning log is additionally written to a file
        # inside the trial directory.
        fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs"))
        fh.setLevel(logging.INFO)
        if args.verbose:
            fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
        logger.addHandler(fh)

    callbacks = []
    if args.train:
        callbacks += [
            ModelCheckpoint(
                dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint
            ),
            LearningRateMonitor(logging_interval="step"),
            CodeSnapshotCallback(
                os.path.join(cfg.trial_dir, "code"), use_version=False
            ),
            ConfigSnapshotCallback(
                args.config,
                cfg,
                os.path.join(cfg.trial_dir, "configs"),
                use_version=False,
            ),
        ]
        if args.gradio:
            callbacks += [
                ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress"))
            ]
        else:
            callbacks += [CustomProgressBar(refresh_rate=1)]

    def write_to_text(file, lines):
        """Write ``lines`` to ``file``, one per line, replacing old content."""
        with open(file, "w") as f:
            for line in lines:
                f.write(line + "\n")

    loggers = []
    if args.train:
        # Create the TensorBoard directory up front. The call is wrapped in
        # rank_zero_only, as is the cmd.txt write below.
        rank_zero_only(
            lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True)
        )()
        loggers += [
            TensorBoardLogger(cfg.trial_dir, name="tb_logs"),
            CSVLogger(cfg.trial_dir, name="csv_logs"),
        ] + system.get_loggers()
        # Record the exact command line and parsed arguments for this run.
        rank_zero_only(
            lambda: write_to_text(
                os.path.join(cfg.trial_dir, "cmd.txt"),
                ["python " + " ".join(sys.argv), str(args)],
            )
        )()

    trainer = Trainer(
        callbacks=callbacks,
        logger=loggers,
        inference_mode=False,
        accelerator="gpu",
        devices=devices,
        **cfg.trainer,
    )

    def set_system_status(system: BaseSystem, ckpt_path: Optional[str]):
        """Copy ``epoch`` and ``global_step`` from a checkpoint to ``system``.

        Does nothing when ``ckpt_path`` is ``None``.
        """
        if ckpt_path is None:
            return
        ckpt = torch.load(ckpt_path, map_location="cpu")
        system.set_resume_status(ckpt["epoch"], ckpt["global_step"])

    if args.train:
        trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume)
        trainer.test(system, datamodule=dm)
        if args.gradio:
            # Gradio mode also runs the predict stage after training.
            trainer.predict(system, datamodule=dm)
    elif args.validate:
        # Outside training, the resume status is set by hand before the
        # stage runs.
        set_system_status(system, cfg.resume)
        trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.test:
        set_system_status(system, cfg.resume)
        trainer.test(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.export:
        set_system_status(system, cfg.resume)
        trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: build the parser, then hand over to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument(
        "--gpu",
        default="0",
        help="GPU(s) to be used. 0 means use the 1st available GPU. "
        "1,2 means use the 2nd and 3rd available GPU. "
        "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, "
        "this argument is ignored and all available GPUs are always used.",
    )

    # Exactly one run mode has to be selected.
    group = parser.add_mutually_exclusive_group(required=True)
    for mode_flag in ("--train", "--validate", "--test", "--export"):
        group.add_argument(mode_flag, action="store_true")

    # Optional boolean switches.
    for switch, description in (
        ("--gradio", "if true, run in gradio mode"),
        ("--verbose", "if true, set logging level to DEBUG"),
        ("--typecheck", "whether to enable dynamic type checking"),
    ):
        parser.add_argument(switch, action="store_true", help=description)

    # Unknown tokens are kept and passed to main() as ``extras``.
    args, extras = parser.parse_known_args()

    # In gradio mode everything written to stdout is sent to stderr instead.
    if args.gradio:
        stdout_ctx = contextlib.redirect_stdout(sys.stderr)
    else:
        stdout_ctx = contextlib.nullcontext()
    with stdout_ctx:
        main(args, extras)
|
|