| | import logging |
| | import os |
| | import torch |
| | import torch.distributed as dist |
| | import yaml |
| |
|
| | from fvcore.nn import FlopCountAnalysis |
| | from fvcore.nn import flop_count_table |
| | from fvcore.nn import flop_count_str |
| |
|
| |
|
# Module-level logger shared by the helpers in this file.
logger = logging.getLogger(__name__)

# Mutable registry of normalization layer classes. It is a list (not a tuple)
# because register_norm_module() appends to it at import time.
NORM_MODULES = [
    # Batch-norm family.
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
    # Other built-in normalization layers.
    torch.nn.GroupNorm,
    torch.nn.InstanceNorm1d,
    torch.nn.InstanceNorm2d,
    torch.nn.InstanceNorm3d,
    torch.nn.LayerNorm,
    torch.nn.LocalResponseNorm,
]
| |
|
def register_norm_module(cls):
    """
    Add ``cls`` to the module-level ``NORM_MODULES`` registry.

    The class is handed back untouched, so this function can be applied as a
    class decorator.

    Args:
        cls: the class to register

    Returns:
        the same ``cls`` object that was passed in
    """
    NORM_MODULES.append(cls)
    return cls
| |
|
| |
|
def is_main_process():
    """
    Tell whether the current process is rank 0 of an Open MPI launch.

    The rank is taken from the ``OMPI_COMM_WORLD_RANK`` environment variable,
    but only when ``OMPI_COMM_WORLD_SIZE`` is present. When it is absent the
    process is treated as the main (and only) process.

    Returns:
        bool: True if the detected rank is 0, False otherwise
    """
    if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
        return True

    # The guard above checks a different variable than the one read here, so
    # fall back to rank 0 instead of raising KeyError when only the size is set.
    rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', 0))
    return rank == 0
| |
|
| |
|
@torch.no_grad()
def analysis_model(model, dump_input, verbose=False):
    """
    Count the FLOPs and parameters of ``model`` and log a summary.

    The model is switched to eval mode for the FLOP analysis and then put back
    into whatever mode it was in before the call, even if the analysis raises.

    Args:
        model: the module to analyse; it must provide ``eval()``, ``train()``
            and ``parameters()`` (presumably a ``torch.nn.Module``)
        dump_input: example input handed to ``FlopCountAnalysis``
        verbose (bool): when True, also log the output of ``flop_count_str``

    Returns:
        tuple: ``(total, table, details)`` where ``total`` is ``flops.total()``,
        ``table`` is ``flop_count_table(flops)`` and ``details`` is
        ``flop_count_str(flops)``
    """
    # Remember the current mode. Objects without a ``training`` attribute keep
    # the previous behaviour of ending up in train mode.
    was_training = getattr(model, 'training', True)
    model.eval()
    try:
        flops = FlopCountAnalysis(model, dump_input)
        total = flops.total()
        # Render each report once; both are logged and returned below.
        table = flop_count_table(flops)
        details = flop_count_str(flops)
    finally:
        model.train(was_training)

    params_total = 0
    params_learned = 0
    for param in model.parameters():
        count = param.numel()
        params_total += count
        if param.requires_grad:
            params_learned += count

    logger.info(f"flop count table:\n {table}")
    if verbose:
        logger.info(f"flop count str:\n {details}")
    logger.info(f" Total flops: {total/1000/1000:.3f}M,")
    logger.info(f" Total params: {params_total/1000/1000:.3f}M,")
    logger.info(f" Learned params: {params_learned/1000/1000:.3f}M")

    return total, table, details
| |
|
| |
|
def load_config_dict_to_opt(opt, config_dict, splitter='.'):
    """
    Load the key, value pairs from config_dict to opt, overriding existing values in opt
    if there is any.

    Each key is split on ``splitter`` and treated as a path into nested
    dictionaries; missing intermediate dictionaries are created in ``opt``.
    A message is printed whenever an existing entry is replaced.

    Args:
        opt (dict): destination dictionary, modified in place
        config_dict (dict): mapping from (possibly dotted) keys to values
        splitter (str): separator between the levels of a key

    Raises:
        TypeError: if ``config_dict`` is not a dictionary
        AssertionError: if an intermediate part of a key resolves to a
            non-dictionary value
    """
    if not isinstance(config_dict, dict):
        raise TypeError("Config must be a Python dictionary")

    # Sentinel to tell "key absent" apart from falsy stored values such as
    # 0, False, "" or None, which are still overrides worth reporting.
    missing = object()

    for k, v in config_dict.items():
        *parent_parts, leaf = k.split(splitter)
        pointer = opt
        for part in parent_parts:
            pointer = pointer.setdefault(part, {})
            # Raised explicitly (same exception type as before) so the check
            # is not stripped when Python runs with -O.
            if not isinstance(pointer, dict):
                raise AssertionError("Overriding key needs to be inside a Python dict.")
        ori_value = pointer.get(leaf, missing)
        pointer[leaf] = v
        if ori_value is not missing:
            print(f"Overrided {k} from {ori_value} to {v}")
| |
|
| |
|
def load_opt_from_config_file(conf_file):
    """
    Build an opt dictionary from a YAML config file.

    The file is read as UTF-8, parsed with ``yaml.safe_load`` and the result is
    merged into a fresh dictionary through ``load_config_dict_to_opt``.

    Args:
        conf_file: config file path

    Returns:
        dict: a dictionary of opt settings
    """
    with open(conf_file, encoding='utf-8') as stream:
        parsed = yaml.safe_load(stream)

    settings = {}
    load_config_dict_to_opt(settings, parsed)
    return settings
| |
|
def cast_batch_to_dtype(batch, dtype):
    """
    Cast the floating-point tensors in a batch to a specified torch dtype.
    It should be called before feeding the batch to the FP16 DeepSpeed model.

    Lists, tuples (including namedtuples) and dicts are traversed recursively.
    Non-floating-point tensors and objects of any other type are returned
    unchanged.

    Args:
        batch (torch.tensor or container of torch.tensor): input batch
        dtype (torch.dtype): target dtype for the floating-point tensors
    Returns:
        return_batch: the batch with every floating-point tensor (of any float
            dtype, not only float32) casted to ``dtype``. Lists, tuples and
            namedtuples keep their type; dicts come back as a plain ``dict``.
    """
    if torch.is_tensor(batch):
        # Integer / bool tensors (e.g. token ids, masks) must keep their dtype.
        return batch.to(dtype) if torch.is_floating_point(batch) else batch

    if isinstance(batch, list):
        return [cast_batch_to_dtype(item, dtype) for item in batch]

    if isinstance(batch, tuple):
        converted = [cast_batch_to_dtype(item, dtype) for item in batch]
        # namedtuples take their fields as positional arguments; rebuilding
        # with the original class keeps attribute access working for callers.
        if hasattr(batch, '_fields'):
            return type(batch)(*converted)
        return tuple(converted)

    if isinstance(batch, dict):
        return {key: cast_batch_to_dtype(value, dtype) for key, value in batch.items()}

    logger.debug(f"Can not cast type {type(batch)} to {dtype}. Skipping it in the batch.")
    return batch
| |
|
| |
|
def cast_batch_to_half(batch):
    """
    Convert a batch to half precision.

    Shorthand for ``cast_batch_to_dtype(batch, torch.float16)``.
    It should be called before feeding the batch to the FP16 DeepSpeed model.

    Args:
        batch (torch.tensor or container of torch.tensor): input batch
    Returns:
        return_batch: whatever ``cast_batch_to_dtype`` returns for
            ``torch.float16``
    """
    half_batch = cast_batch_to_dtype(batch, torch.float16)
    return half_batch
| |
|