| import torch |
| from collections import OrderedDict |
| from diffusers.utils import logging |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| def get_fsdp_plugin(fsdp_cfg, mixed_precision): |
| import functools |
| from torch.distributed.fsdp.fully_sharded_data_parallel import ( |
| BackwardPrefetch, |
| CPUOffload, |
| ShardingStrategy, |
| MixedPrecision, |
| StateDictType, |
| FullStateDictConfig, |
| FullOptimStateDictConfig, |
| ) |
| from accelerate.utils import FullyShardedDataParallelPlugin |
| from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy |
|
|
| if mixed_precision == "fp16": |
| dtype = torch.float16 |
| elif mixed_precision == "bf16": |
| dtype = torch.bfloat16 |
| else: |
| dtype = torch.float32 |
| fsdp_plugin = FullyShardedDataParallelPlugin( |
| sharding_strategy={ |
| "FULL_SHARD": ShardingStrategy.FULL_SHARD, |
| "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP, |
| "NO_SHARD": ShardingStrategy.NO_SHARD, |
| "HYBRID_SHARD": ShardingStrategy.HYBRID_SHARD, |
| "HYBRID_SHARD_ZERO2": ShardingStrategy._HYBRID_SHARD_ZERO2, |
| }[fsdp_cfg.sharding_strategy], |
| backward_prefetch={ |
| "BACKWARD_PRE": BackwardPrefetch.BACKWARD_PRE, |
| "BACKWARD_POST": BackwardPrefetch.BACKWARD_POST, |
| }[fsdp_cfg.backward_prefetch], |
| mixed_precision_policy=MixedPrecision( |
| param_dtype=dtype, |
| reduce_dtype=dtype, |
| ), |
| auto_wrap_policy=functools.partial( |
| size_based_auto_wrap_policy, min_num_params=fsdp_cfg.min_num_params |
| ), |
| cpu_offload=CPUOffload(offload_params=fsdp_cfg.cpu_offload), |
| state_dict_type={ |
| "FULL_STATE_DICT": StateDictType.FULL_STATE_DICT, |
| "LOCAL_STATE_DICT": StateDictType.LOCAL_STATE_DICT, |
| "SHARDED_STATE_DICT": StateDictType.SHARDED_STATE_DICT, |
| }[fsdp_cfg.state_dict_type], |
| state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), |
| optim_state_dict_config=FullOptimStateDictConfig( |
| offload_to_cpu=True, rank0_only=True |
| ), |
| limit_all_gathers=fsdp_cfg.limit_all_gathers, |
| use_orig_params=fsdp_cfg.use_orig_params, |
| sync_module_states=fsdp_cfg.sync_module_states, |
| forward_prefetch=fsdp_cfg.forward_prefetch, |
| activation_checkpointing=fsdp_cfg.activation_checkpointing, |
| ) |
| return fsdp_plugin |
|
|
|
|
| def freeze_model(model, trainable_modules={}, verbose=False): |
| logger.info("Start freeze") |
| for name, param in model.named_parameters(): |
| |
| if verbose: |
| logger.info("freeze moduel: " + str(name)) |
| for trainable_module_name in trainable_modules: |
| if trainable_module_name in name: |
| |
| if verbose: |
| logger.info("unfreeze moduel: " + str(name)) |
| break |
| logger.info("End freeze") |
| |
| |
| |
| |
| return |
|
|
|
|
| @torch.no_grad() |
| def update_ema(ema_model, model, decay=0.9999): |
| """ |
| Step the EMA model towards the current model. |
| """ |
| if hasattr(model, "module"): |
| model = model.module |
| if hasattr(ema_model, "module"): |
| ema_model = ema_model.module |
| ema_params = OrderedDict(ema_model.named_parameters()) |
| model_params = OrderedDict(model.named_parameters()) |
|
|
| for name, param in model_params.items(): |
| |
| ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) |
|
|
|
|
| def log_validation(model): |
| pass |
|
|