| | """ |
| | Helpers to train with 16-bit precision. |
| | """ |
| |
|
| | import numpy as np |
| | import torch as th |
| | import torch.nn as nn |
| | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors |
| |
|
| | from . import logger |
| |
|
# Default starting value for the trainer's log-scale loss multiplier. It is a
# base-2 logarithm: the loss is multiplied by 2 ** lg_loss_scale before
# backward(), so 20.0 corresponds to an initial scale of 2 ** 20.
INITIAL_LOG_LOSS_SCALE = 20.0
| |
|
| |
|
def convert_module_to_f16(l):
    """
    Cast the weight (and bias, when present) of a convolution module to half
    precision, in place.

    Only nn.Conv1d, nn.Conv2d and nn.Conv3d instances are touched; any other
    module is left exactly as it is.
    """
    if not isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return
    l.weight.data = l.weight.data.half()
    # Convolutions built with bias=False expose `bias` as None.
    if l.bias is None:
        return
    l.bias.data = l.bias.data.half()
| |
|
| |
|
def convert_module_to_f32(l):
    """
    Cast the weight (and bias, when present) of a convolution module back to
    single precision, in place. This is the inverse of convert_module_to_f16().

    Only nn.Conv1d, nn.Conv2d and nn.Conv3d instances are touched; any other
    module is left exactly as it is.
    """
    if not isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return
    l.weight.data = l.weight.data.float()
    # Convolutions built with bias=False expose `bias` as None.
    if l.bias is None:
        return
    l.bias.data = l.bias.data.float()
| |
|
| |
|
def make_master_params(param_groups_and_shapes):
    """
    Build one full-precision master parameter per parameter group.

    :param param_groups_and_shapes: an iterable of (param_group, shape) pairs,
                                    where param_group is a sequence of
                                    (name, param) tuples.
    :return: a list of nn.Parameter objects, one per group. Each holds the
             float32 copies of the group's parameters, flattened into a single
             tensor and viewed as `shape`.

    NOTE(review): an empty param_group is not special-cased here; flattening
    an empty list is expected to raise -- confirm against the torch version
    in use.
    """

    def _build_master(param_group, shape):
        # Detach first so the master copy does not share autograd history
        # with the model parameters.
        fp32_copies = [param.detach().float() for (_, param) in param_group]
        flat = _flatten_dense_tensors(fp32_copies)
        master = nn.Parameter(flat.view(shape))
        master.requires_grad = True
        return master

    return [
        _build_master(param_group, shape)
        for param_group, shape in param_groups_and_shapes
    ]
| |
|
| |
|
def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Populate the `.grad` of every master parameter from the gradients of the
    model parameters in the matching group.

    :param param_groups_and_shapes: the (param_group, shape) pairs that were
                                    given to make_master_params().
    :param master_params: the master parameters, in the same order as the
                          groups.

    Model parameters without a gradient contribute zeros (see
    param_grad_or_zeros()).
    """
    paired = zip(master_params, param_groups_and_shapes)
    for master, (group, shape) in paired:
        group_grads = [param_grad_or_zeros(param) for (_, param) in group]
        master.grad = _flatten_dense_tensors(group_grads).view(shape)
| |
|
| |
|
def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Write the values held by the master parameters back into the model
    parameters, in place.

    :param param_groups_and_shapes: the (param_group, shape) pairs that were
                                    given to make_master_params().
    :param master_params: the master parameters, in the same order as the
                          groups.
    """
    for master, (group, _) in zip(master_params, param_groups_and_shapes):
        # Split the flat master tensor into pieces shaped like each model
        # parameter of the group.
        pieces = unflatten_master_params(group, master.view(-1))
        for (_, model_param), piece in zip(group, pieces):
            model_param.detach().copy_(piece)
| |
|
| |
|
def unflatten_master_params(param_group, master_param):
    """
    Split a flat master tensor into tensors shaped like the parameters of
    `param_group`.

    :param param_group: a sequence of (name, param) tuples.
    :param master_param: the flat tensor holding the group's values.
    :return: the result of _unflatten_dense_tensors(), one tensor per entry
             of the group.
    """
    templates = [param for (_, param) in param_group]
    return _unflatten_dense_tensors(master_param, templates)
| |
|
| |
|
def get_param_groups_and_shapes(named_model_params):
    """
    Partition named parameters by rank and pair each partition with the view
    shape used for its master parameter.

    :param named_model_params: an iterable of (name, param) tuples.
    :return: a two-element list:
             - (params with ndim <= 1, -1)
             - (params with ndim > 1, (1, -1))
             Each partition keeps the input order.
    """
    scalars_and_vectors = []
    matrices = []
    for name, param in named_model_params:
        bucket = scalars_and_vectors if param.ndim <= 1 else matrices
        bucket.append((name, param))
    return [(scalars_and_vectors, -1), (matrices, (1, -1))]
| |
|
| |
|
def master_params_to_state_dict(
    model, param_groups_and_shapes, master_params, use_fp16
):
    """
    Produce a state dict for `model` whose parameter entries are taken from
    the master parameters.

    :param model: the module whose state_dict() supplies the starting dict.
    :param param_groups_and_shapes: the (param_group, shape) pairs matching
                                    `master_params`; only read when use_fp16
                                    is true.
    :param master_params: with use_fp16, one flat master parameter per group;
                          otherwise one entry per model parameter, in
                          named_parameters() order.
    :param use_fp16: selects which of the two layouts above is in use.
    :return: the updated state dict.
    """
    state_dict = model.state_dict()

    if not use_fp16:
        # Master params line up one-to-one with the model's named parameters.
        for index, (name, _value) in enumerate(model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[index]
        return state_dict

    for master, (group, _) in zip(master_params, param_groups_and_shapes):
        pieces = unflatten_master_params(group, master.view(-1))
        for (name, _), piece in zip(group, pieces):
            assert name in state_dict
            state_dict[name] = piece
    return state_dict
| |
|
| |
|
def state_dict_to_master_params(model, state_dict, use_fp16):
    """
    Rebuild the list of master parameters from a state dict.

    :param model: the module whose named_parameters() determines which
                  entries are read and in what order.
    :param state_dict: a mapping from parameter name to tensor.
    :param use_fp16: if true, the tensors are regrouped and flattened via
                     get_param_groups_and_shapes() and make_master_params();
                     otherwise they are returned as they are.
    :return: a list of master parameters.
    """
    named_tensors = [
        (name, state_dict[name]) for name, _ in model.named_parameters()
    ]
    if not use_fp16:
        return [tensor for _, tensor in named_tensors]
    return make_master_params(get_param_groups_and_shapes(named_tensors))
| |
|
| |
|
def zero_master_grads(master_params):
    """
    Drop the gradient of every master parameter by setting `.grad` to None.
    """
    for master in master_params:
        master.grad = None
| |
|
| |
|
def zero_grad(model_params):
    """
    Zero the gradients of the given parameters in place.

    Parameters whose `.grad` is None are skipped; existing gradient tensors
    are detached from any graph and then filled with zeros.
    """
    for param in model_params:
        grad = param.grad
        if grad is None:
            continue
        grad.detach_()
        grad.zero_()
| |
|
| |
|
def param_grad_or_zeros(param):
    """
    Return the detached gradient of `param`, or a zero tensor shaped like
    `param` when it has no gradient.
    """
    if param.grad is None:
        return th.zeros_like(param)
    return param.grad.data.detach()
| |
|
| |
|
class MixedPrecisionTrainer:
    """
    Wrap a model's backward pass and optimizer step, optionally using fp16
    model weights with full-precision master parameters and a dynamic loss
    scale.

    :param model: the module to train. With use_fp16 it must provide a
                  convert_to_fp16() method, which is called at construction.
    :param use_fp16: if true, keep flattened fp32 master parameters and scale
                     the loss by 2 ** lg_loss_scale before backward().
    :param fp16_scale_growth: amount added to lg_loss_scale after every
                              successful fp16 step.
    :param initial_lg_loss_scale: starting value of lg_loss_scale (a base-2
                                  logarithm).

    The optimizer passed to optimize() is expected to have been built over
    `master_params`.
    """

    def __init__(
        self,
        *,
        model,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
    ):
        self.model = model
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth

        # Without fp16 the master parameters are the model parameters.
        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale

        if not self.use_fp16:
            return

        # Snapshot the parameters into fp32 masters before the model is
        # converted to half precision.
        self.param_groups_and_shapes = get_param_groups_and_shapes(
            self.model.named_parameters()
        )
        self.master_params = make_master_params(self.param_groups_and_shapes)
        self.model.convert_to_fp16()

    def zero_grad(self):
        """
        Zero the gradients of the model parameters.
        """
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor):
        """
        Run backward() on the loss, multiplied by the current loss scale when
        fp16 is enabled.
        """
        if not self.use_fp16:
            loss.backward()
            return
        scaled_loss = loss * (2 ** self.lg_loss_scale)
        scaled_loss.backward()

    def optimize(self, opt: th.optim.Optimizer):
        """
        Take one optimizer step.

        :return: True if the step was applied, False if it was skipped
                 because the fp16 gradients overflowed.
        """
        step_fn = self._optimize_fp16 if self.use_fp16 else self._optimize_normal
        return step_fn(opt)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        """
        fp16 step: move the gradients to the master parameters, skip the step
        and lower the loss scale on overflow, otherwise unscale, step, and
        copy the updated masters back into the model.
        """
        logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
        model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)

        loss_scale = 2 ** self.lg_loss_scale
        grad_norm, param_norm = self._compute_norms(grad_scale=loss_scale)

        if check_overflow(grad_norm):
            # Halve the scale, discard the unusable gradients, skip the step.
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            zero_master_grads(self.master_params)
            return False

        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)

        # Undo the scaling that backward() applied to the loss.
        unscale = 1.0 / loss_scale
        for master in self.master_params:
            master.grad.mul_(unscale)
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_normal(self, opt: th.optim.Optimizer):
        """
        Full-precision step: log the norms and step the optimizer.
        """
        grad_norm, param_norm = self._compute_norms()
        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)
        opt.step()
        return True

    def _compute_norms(self, grad_scale=1.0):
        """
        Compute global L2 norms over the master parameters.

        :param grad_scale: divisor applied to the gradient norm.
        :return: a (grad_norm, param_norm) tuple; parameters without a
                 gradient do not contribute to grad_norm.
        """
        grad_sq_sum = 0.0
        param_sq_sum = 0.0
        with th.no_grad():
            for master in self.master_params:
                param_sq_sum += th.norm(master, p=2, dtype=th.float32).item() ** 2
                if master.grad is not None:
                    grad_sq_sum += (
                        th.norm(master.grad, p=2, dtype=th.float32).item() ** 2
                    )
        return np.sqrt(grad_sq_sum) / grad_scale, np.sqrt(param_sq_sum)

    def master_params_to_state_dict(self, master_params):
        """
        Convert master parameters into a state dict for this trainer's model.
        """
        return master_params_to_state_dict(
            self.model, self.param_groups_and_shapes, master_params, self.use_fp16
        )

    def state_dict_to_master_params(self, state_dict):
        """
        Convert a state dict into master parameters for this trainer's model.
        """
        return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
| |
|
| |
|
def check_overflow(value):
    """
    Report whether `value` is positive infinity, negative infinity, or NaN.
    """
    infinity = float("inf")
    # NaN is the only value that compares unequal to itself.
    return value == infinity or value == -infinity or value != value
| |
|