import os
from functools import partial
from typing import Union, List
from pathlib import Path
from datetime import datetime, timedelta

from omegaconf import DictConfig
from pprint import pprint
import torch
from accelerate.utils import LoggerType
from accelerate import (
    Accelerator,
    GradScalerKwargs,
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs
)

from ..modules.ema import EMA
from ..utils.logging import get_logger


class ModelState:
| | """ |
| | Handling logger and `hugging face` accelerate training |
| | |
| | features: |
| | - Mixed Precision |
| | - Gradient Scaler |
| | - Gradient Accumulation |
| | - Optimizer |
| | - EMA |
| | - Logger (default: python print) |
| | - Monitor (default: wandb, tensorboard) |
| | """ |

    def __init__(
            self,
            args,
            log_path_suffix: str = None,
            ignore_log=False,
    ) -> None:
        self.args: DictConfig = args

        """check validity"""
        mixed_precision = self.args.get("mixed_precision")
        # a bare bool in the config means mixed precision was not specified
        mixed_precision = "no" if isinstance(mixed_precision, bool) else mixed_precision
        split_batches = self.args.get("split_batches", False)
        gradient_accumulate_step = self.args.get("gradient_accumulate_step", 1)
        assert gradient_accumulate_step >= 1, \
            f"expected gradient_accumulate_step >= 1, got {gradient_accumulate_step}"

        """create working space"""
        config_name_only = str(self.args.get("config")).split(".")[0]
        results_folder = self.args.get("results_path", None)
        if results_folder is None:
            self.results_path = Path("./workdir")
        else:
            self.results_path = Path(os.path.join(results_folder, self.args.get("edit_type")))

        if log_path_suffix is not None:
            self.results_path = self.results_path / log_path_suffix

        kwargs_handlers = []
        """mixed precision training"""
        # the GradScaler only takes effect when training in fp16
        if mixed_precision == "fp16":
            scaler_handler = GradScalerKwargs(
                init_scale=args.init_scale,
                growth_factor=args.growth_factor,
                backoff_factor=args.backoff_factor,
                growth_interval=args.growth_interval,
                enabled=True
            )
            kwargs_handlers.append(scaler_handler)

        """distributed training"""
        ddp_handler = DistributedDataParallelKwargs(
            dim=0,
            broadcast_buffers=True,
            static_graph=False,
            bucket_cap_mb=25,
            find_unused_parameters=False,
            check_reduction=False,
            gradient_as_bucket_view=False
        )
        kwargs_handlers.append(ddp_handler)

        init_handler = InitProcessGroupKwargs(timeout=timedelta(seconds=1200))
        kwargs_handlers.append(init_handler)
| | """init visualized tracker""" |
| | log_with = [] |
| | self.args.visual = False |
| | if args.use_wandb: |
| | log_with.append(LoggerType.WANDB) |
| | if args.tensorboard: |
| | log_with.append(LoggerType.TENSORBOARD) |
| |
|
| | """hugging face Accelerator""" |
| | self.accelerator = Accelerator( |
| | device_placement=True, |
| | split_batches=split_batches, |
| | mixed_precision=mixed_precision, |
| | gradient_accumulation_steps=args.gradient_accumulate_step, |
| | cpu=True if args.use_cpu else False, |
| | log_with=None if len(log_with) == 0 else log_with, |
| | project_dir=self.results_path / "vis", |
| | kwargs_handlers=kwargs_handlers, |
| | ) |
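
        # Note: models, optimizers, and dataloaders built by the caller are
        # expected to be wired through `self.accelerator.prepare(...)`, which
        # applies the handlers configured above (a convention, not enforced here).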

        """logs"""
        if self.accelerator.is_local_main_process:
            self.results_path.mkdir(parents=True, exist_ok=True)
            if not ignore_log:
                now_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
| | print("==> command line args: ") |
| | print(args.cmd_args) |
| | print("==> yaml config args: ") |
| | print(args.yaml_config) |
| |
|
| | print("\n***** Model State *****") |
| | if self.accelerator.distributed_type != "NO": |
| | print(f"-> Distributed Type: {self.accelerator.distributed_type}") |
| | print(f"-> Split Batch Size: {split_batches}, Total Batch Size: {self.actual_batch_size}") |
| | print(f"-> Mixed Precision: {mixed_precision}, AMP: {self.accelerator.native_amp}," |
| | f" Gradient Accumulate Step: {gradient_accumulate_step}") |
| | print(f"-> Weight dtype: {self.weight_dtype}") |
| |
|

            if self.accelerator.scaler_handler is not None and self.accelerator.scaler_handler.enabled:
                print(f"-> Enabled GradScaler: {self.accelerator.scaler_handler.to_kwargs()}")

            if args.use_wandb:
                print("-> Init trackers: 'wandb'")
                self.args.visual = True
                self.__init_tracker(project_name="my_project", tags=None, entity="")

            print(f"-> Working Space: '{self.results_path}'")

        """EMA"""
        self.use_ema = args.get('ema', False)
        self.ema_wrapper = self.__build_ema_wrapper()

        """global step"""
        self.step = 0

        """log process"""
        self.accelerator.wait_for_everyone()
        print(f'Process {self.accelerator.process_index} using device: {self.accelerator.device}')

        self.print("-> state initialization complete\n")

    def __init_tracker(self, project_name, tags, entity):
        self.accelerator.init_trackers(
            project_name=project_name,
            config=dict(self.args),
            init_kwargs={
                "wandb": {
                    "notes": "accelerate trainer pipeline",
                    "tags": [f"total batch_size: {self.actual_batch_size}"] + (tags or []),
                    "entity": entity,
                }
            }
        )
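
    # Once trackers are initialized, metrics can be logged via accelerate's
    # built-in call, e.g. `self.accelerator.log({"loss": loss_value}, step=self.step)`.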

    def __build_ema_wrapper(self):
        if self.use_ema:
            self.print(f"-> EMA: {self.use_ema}, decay: {self.args.ema_decay}, "
                       f"update_after_step: {self.args.ema_update_after_step}, "
                       f"update_every: {self.args.ema_update_every}")
            ema_wrapper = partial(
                EMA, beta=self.args.ema_decay,
                update_after_step=self.args.ema_update_after_step,
                update_every=self.args.ema_update_every
            )
        else:
            ema_wrapper = None

        return ema_wrapper
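
    # Usage sketch (hedged; the EMA class comes from `..modules.ema` and its
    # exact API is assumed here): wrap the raw model once, then update it
    # after each optimizer step:
    #   ema = self.ema_wrapper(model)
    #   ema.update()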

    @property
    def device(self):
        return self.accelerator.device

    @property
    def weight_dtype(self):
        weight_dtype = torch.float32
        if self.accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16
        return weight_dtype

    @property
    def actual_batch_size(self):
        if self.accelerator.split_batches is False:
            actual_batch_size = self.args.batch_size * self.accelerator.num_processes \
                                * self.accelerator.gradient_accumulation_steps
        else:
            # with split_batches, the configured batch size is the global one
            # and must divide evenly across processes
            assert self.args.batch_size % self.accelerator.num_processes == 0
            actual_batch_size = self.args.batch_size
        return actual_batch_size
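
    # Worked example (illustrative numbers): with split_batches=False,
    # batch_size=8, 4 processes, and gradient_accumulation_steps=2, the
    # effective total batch size is 8 * 4 * 2 = 64.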

    @property
    def n_gpus(self):
        return self.accelerator.num_processes

    @property
    def no_decay_params_names(self):
        no_decay = [
            "bn", "LayerNorm", "GroupNorm",
        ]
        return no_decay

    def no_decay_params(self, model, weight_decay):
        """Group parameters so that normalization layers receive no weight
        decay, a common optimization trick.
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in self.no_decay_params_names)
                ],
                "weight_decay": weight_decay,
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in self.no_decay_params_names)
                ],
                "weight_decay": 0.0,
            },
        ]
        return optimizer_grouped_parameters

    def optimized_params(self, model: torch.nn.Module, verbose=True) -> List:
        """Return the parameters whose `requires_grad` is True.

        Args:
            model: a pytorch model
            verbose: log the optimized parameters

        Examples:
            >>> self.params_optimized = self.optimized_params(uvit, verbose=True)
            >>> optimizer = torch.optim.AdamW(self.params_optimized, lr=args.lr)

        Returns:
            a list of parameters
        """
        params_optimized = []
        for key, value in model.named_parameters():
            if value.requires_grad:
                params_optimized.append(value)
                if verbose:
                    self.print("\t {}, {}, {}".format(key, value.numel(), value.shape))
        return params_optimized

    def save_everything(self, fpath: str):
        """Save the model, optimizer, RNG generators, and the GradScaler."""
        if not self.accelerator.is_main_process:
            return
        self.accelerator.save_state(fpath)

    def load_save_everything(self, fpath: str):
        """Load the model, optimizer, RNG generators, and the GradScaler."""
        self.accelerator.load_state(fpath)
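
    # Checkpointing sketch (paths are illustrative): `save_everything("ckpt/step_1000")`
    # persists the full training state via accelerate's `save_state`, and a later
    # `load_save_everything("ckpt/step_1000")` restores it with `load_state`.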

    def save(self, milestone: Union[str, float, int], checkpoint: object) -> None:
        if not self.accelerator.is_main_process:
            return

        torch.save(checkpoint, self.results_path / f'model-{milestone}.pt')

    def save_in(self, root: Union[str, Path], checkpoint: object) -> None:
        if not self.accelerator.is_main_process:
            return

        torch.save(checkpoint, root)

    def load_ckpt_model_only(self, model: torch.nn.Module, path: Union[str, Path], rm_module_prefix: bool = False):
        ckpt = torch.load(path, map_location=self.accelerator.device)

        unwrapped_model = self.accelerator.unwrap_model(model)
        if rm_module_prefix:
            # strip the 'module.' prefix left by DataParallel/DDP checkpoints
            unwrapped_model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()})
        else:
            unwrapped_model.load_state_dict(ckpt)
        return unwrapped_model

    def load_shared_weights(self, model: torch.nn.Module, path: Union[str, Path]):
        ckpt = torch.load(path, map_location=self.accelerator.device)
        self.print(f"pretrained_dict len: {len(ckpt)}")
        unwrapped_model = self.accelerator.unwrap_model(model)
        model_dict = unwrapped_model.state_dict()
        # keep only the weights whose names exist in the current model
        pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        unwrapped_model.load_state_dict(model_dict, strict=False)
        self.print(f"selected pretrained_dict: {len(pretrained_dict)}")
        return unwrapped_model

    def print(self, *args, **kwargs):
        """Use as a drop-in replacement for `print()` to print only once per server."""
        self.accelerator.print(*args, **kwargs)

    def pretty_print(self, msg):
        if self.accelerator.is_local_main_process:
            pprint(dict(msg))

    def close_tracker(self):
        self.accelerator.end_training()

    def free_memory(self):
        self.accelerator.clear()

    def close(self, msg: str = "Training complete."):
        """Call at the end of training."""
        self.free_memory()

        if torch.cuda.is_available():
            self.print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
        if self.args.visual:
            self.close_tracker()
        self.print(msg)