import torch
from os.path import join
import torch.distributed as dist
from .utilities import check_makedirs
from collections import OrderedDict
from torch.nn.parallel import DataParallel, DistributedDataParallel


def step_learning_rate(base_lr, epoch, step_epoch, multiplier=0.1):
    """Step learning rate policy: decay base_lr by `multiplier` every `step_epoch` epochs."""
    lr = base_lr * (multiplier ** (epoch // step_epoch))
    return lr


def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9):
    """Poly learning rate policy: base_lr * (1 - curr_iter / max_iter) ** power."""
    lr = base_lr * (1 - float(curr_iter) / max_iter) ** power
    return lr


def adjust_learning_rate(optimizer, lr):
    """Set the learning rate of every parameter group of the optimizer to `lr`."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


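# A minimal usage sketch (hypothetical training loop; `optimizer`, `base_lr` and
# `max_iter` are assumed to exist in the caller):
#
#     for curr_iter in range(max_iter):
#         lr = poly_learning_rate(base_lr, curr_iter=curr_iter, max_iter=max_iter)
#         adjust_learning_rate(optimizer, lr)
#         ...  # forward / backward / optimizer.step()

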
def save_checkpoint(model, other_state=None, sav_path='', filename='model.pth.tar', stage=1):
    """Save the model weights (plus any extra state in `other_state`) to sav_path/filename."""
    if isinstance(model, (DistributedDataParallel, DataParallel)):
        weight = model.module.state_dict()
    elif isinstance(model, torch.nn.Module):
        weight = model.state_dict()
    else:
        raise ValueError('model must be an nn.Module (optionally wrapped in DataParallel '
                         'or DistributedDataParallel)!')
    check_makedirs(sav_path)

    # In stage 2 the autoencoder weights are not needed, so drop them from the checkpoint.
    if stage == 2:
        for k in list(weight.keys()):
            if 'autoencoder' in k:
                weight.pop(k)

    if other_state is None:
        other_state = {}
    other_state['state_dict'] = weight
    filename = join(sav_path, filename)
    torch.save(other_state, filename)


def load_state_dict(model, state_dict, strict=True):
    """Load a state dict into the model, unwrapping a (Distributed)DataParallel wrapper if present."""
    if isinstance(model, (DistributedDataParallel, DataParallel)):
        model.module.load_state_dict(state_dict, strict=strict)
    else:
        model.load_state_dict(state_dict, strict=strict)


def state_dict_remove_module(state_dict):
    """Strip the 'module.' prefix added by (Distributed)DataParallel from all state dict keys."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k.replace('module.', '')
        new_state_dict[name] = v
    return new_state_dict


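# A minimal save/load round-trip sketch (hypothetical names: `model`, `epoch` and
# `save_dir` are assumed to exist in the caller):
#
#     save_checkpoint(model, other_state={'epoch': epoch}, sav_path=save_dir)
#     ckpt = torch.load(join(save_dir, 'model.pth.tar'), map_location='cpu')
#     load_state_dict(model, state_dict_remove_module(ckpt['state_dict']))

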
def reduce_tensor(tensor, args):
    """Average a tensor over all distributed processes (assumes args.world_size ranks)."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt


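# A minimal sketch of logging a loss averaged across ranks (assumes the default
# process group is initialized and `loss`/`args` exist in the caller):
#
#     reduced_loss = reduce_tensor(loss.detach(), args)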