| from typing import Callable |
|
|
| import math |
| import warnings |
|
|
| import torch |
| from torch import nn, Tensor |
|
|
def named_apply(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` over a module tree.

    Args:
        fn: Callable invoked with the keyword arguments ``module`` and ``name``.
        module: Root of the (sub)tree to walk.
        name: Dotted name of ``module``; used as the prefix for child names.
        depth_first: When true a module is visited after its children,
            otherwise before them.
        include_root: Whether ``fn`` is applied to ``module`` itself. Every
            descendant is always visited, because the recursion passes
            ``include_root=True``.

    Returns:
        The ``module`` that was passed in.
    """
    visit_before_children = include_root and not depth_first
    visit_after_children = include_root and depth_first

    if visit_before_children:
        fn(module=module, name=name)

    for local_name, submodule in module.named_children():
        # Build the dotted path; an empty prefix must not add a leading dot.
        qualified_name = f"{name}.{local_name}" if name else local_name
        named_apply(
            fn=fn,
            module=submodule,
            name=qualified_name,
            depth_first=depth_first,
            include_root=True,
        )

    if visit_after_children:
        fn(module=module, name=name)

    return module
|
|
|
|
| def _no_grad_trunc_normal_(tensor, mean, std, a, b): |
| |
| |
| def norm_cdf(x): |
| |
| return (1. + math.erf(x / math.sqrt(2.))) / 2. |
|
|
| if (mean < a - 2 * std) or (mean > b + 2 * std): |
| warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
| "The distribution of values may be incorrect.", |
| stacklevel=2) |
|
|
| with torch.no_grad(): |
| |
| |
| |
| l = norm_cdf((a - mean) / std) |
| u = norm_cdf((b - mean) / std) |
|
|
| |
| |
| tensor.uniform_(2 * l - 1, 2 * u - 1) |
|
|
| |
| |
| tensor.erfinv_() |
|
|
| |
| tensor.mul_(std * math.sqrt(2.)) |
| tensor.add_(mean) |
|
|
| |
| tensor.clamp_(min=a, max=b) |
| return tensor |
|
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place with truncated-normal values and return it.

    Public entry point that forwards every argument unchanged to
    ``_no_grad_trunc_normal_``.

    Args:
        tensor: Tensor to overwrite.
        mean: Mean of the normal distribution (default ``0.``).
        std: Standard deviation of the normal distribution (default ``1.``).
        a: Lower truncation bound (default ``-2.``).
        b: Upper truncation bound (default ``2.``).

    Returns:
        Whatever ``_no_grad_trunc_normal_`` returns, i.e. the filled tensor.
    """
    filled = _no_grad_trunc_normal_(tensor=tensor, mean=mean, std=std, a=a, b=b)
    return filled