| | """Utility functions for initializing weights and biases.""" |

import math

import torch


def _calculate_fan(linear_weight_shape, fan="fan_in"):
    fan_out, fan_in = linear_weight_shape

    if fan == "fan_in":
        f = fan_in
    elif fan == "fan_out":
        f = fan_out
    elif fan == "fan_avg":
        f = (fan_in + fan_out) / 2
    else:
        raise ValueError("Invalid fan option")

    return f
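
# Example: for linear_weight_shape == (256, 128), fan_in is 128, fan_out
# is 256, and fan_avg is (128 + 256) / 2 == 192.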


def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    # Variance-scaled truncated normal: var = scale / fan, with samples
    # truncated at two standard deviations.
    shape = weights.shape
    f = _calculate_fan(shape, fan)
    scale = scale / max(1, f)
    std = math.sqrt(scale)
    with torch.no_grad():
        torch.nn.init.trunc_normal_(weights, mean=0.0, std=std, a=-2 * std, b=2 * std)
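
# Worked example (hypothetical shape, not from the original module): for a
# weight of shape (fan_out, fan_in) = (256, 128) with the default
# fan="fan_in", scale becomes 1.0 / 128 and std = sqrt(1 / 128) ~= 0.0884;
# samples are then drawn from N(0, std^2) clipped to [-2 * std, 2 * std].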


def lecun_normal_init_(weights):
    # LeCun normal: variance 1 / fan_in.
    trunc_normal_init_(weights, scale=1.0)


def he_normal_init_(weights):
    # He normal: variance 2 / fan_in, appropriate for ReLU activations.
    trunc_normal_init_(weights, scale=2.0)


def glorot_uniform_init_(weights):
    torch.nn.init.xavier_uniform_(weights, gain=1)
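
# For reference: torch.nn.init.xavier_uniform_ draws from U(-a, a) with
# a = gain * sqrt(6 / (fan_in + fan_out)), a fan_avg-scaled uniform
# counterpart to the truncated-normal initializers above.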


def final_init_(weights):
    # Zero initialization, typically used for the final layer of a block.
    with torch.no_grad():
        weights.fill_(0.0)


def gating_init_(weights):
    # Zero initialization for gating weights.
    with torch.no_grad():
        weights.fill_(0.0)


def bias_init_zero_(bias):
    with torch.no_grad():
        bias.fill_(0.0)


def bias_init_one_(bias):
    with torch.no_grad():
        bias.fill_(1.0)
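
# A plausible pairing (an assumption, not stated in this module): gated
# linears commonly combine gating_init_ on the weight with bias_init_one_
# on the bias, so a sigmoid gate starts near-open at sigmoid(1) ~= 0.73.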


def normal_init_(weights):
    # Untruncated normal with std = 1 / sqrt(fan_in) (gain 1 for "linear").
    torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")


def ipa_point_weights_init_(weights):
    with torch.no_grad():
        # ln(e - 1), i.e. the inverse of softplus at 1, so that
        # softplus(weights) starts out at exactly 1.
        softplus_inverse_1 = 0.541324854612918
        weights.fill_(softplus_inverse_1)
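

# Minimal usage sketch (not part of the original module); the Linear layer
# below is a hypothetical example.
if __name__ == "__main__":
    linear = torch.nn.Linear(128, 256)  # weight shape: (fan_out=256, fan_in=128)
    lecun_normal_init_(linear.weight)
    bias_init_zero_(linear.bias)
    # The empirical std falls slightly below sqrt(1 / 128) ~= 0.088 because
    # samples are truncated at two standard deviations.
    print(linear.weight.std().item())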