"""Utility functions for initializing weights and biases."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
|
|
|
|
|
|
import torch


def _calculate_fan(linear_weight_shape, fan="fan_in"):
    """Return the fan value for a 2-D linear weight shape.

    Args:
        linear_weight_shape: Weight shape, ordered (fan_out, fan_in).
        fan: One of "fan_in", "fan_out", or "fan_avg".
    """
    fan_out, fan_in = linear_weight_shape

    if fan == "fan_in":
        f = fan_in
    elif fan == "fan_out":
        f = fan_out
    elif fan == "fan_avg":
        f = (fan_in + fan_out) / 2
    else:
        raise ValueError(f"Invalid fan option: {fan}")

    return f


def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    """Fill weights in-place from a truncated normal distribution.

    The variance is scale / fan, and samples are truncated at two
    standard deviations from the mean.
    """
    f = _calculate_fan(weights.shape, fan)
    std = math.sqrt(scale / max(1, f))
    with torch.no_grad():
        torch.nn.init.trunc_normal_(weights, mean=0.0, std=std, a=-2 * std, b=2 * std)


def lecun_normal_init_(weights):
    """LeCun initialization: truncated normal with variance 1 / fan_in."""
    trunc_normal_init_(weights, scale=1.0)


def he_normal_init_(weights):
    """He initialization: truncated normal with variance 2 / fan_in."""
    trunc_normal_init_(weights, scale=2.0)


def glorot_uniform_init_(weights):
    """Glorot (Xavier) uniform initialization with unit gain."""
    torch.nn.init.xavier_uniform_(weights, gain=1)


def final_init_(weights):
    """Zero-initialize weights, e.g. for a block's final output projection."""
    with torch.no_grad():
        weights.fill_(0.0)


def gating_init_(weights):
    """Zero-initialize gating weights.

    Combined with bias_init_one_, gates start near sigmoid(1) ~= 0.73.
    """
    with torch.no_grad():
        weights.fill_(0.0)


def bias_init_zero_(bias):
    """Initialize a bias to zero in-place."""
    with torch.no_grad():
        bias.fill_(0.0)


def bias_init_one_(bias):
    """Initialize a bias to one in-place."""
    with torch.no_grad():
        bias.fill_(1.0)


def normal_init_(weights):
    """Kaiming normal initialization with unit gain (std = 1 / sqrt(fan_in))."""
    torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")


def ipa_point_weights_init_(weights):
    """Initialize IPA point-attention weights so softplus(weight) == 1.

    0.541324854612918 is softplus^-1(1) = log(e - 1), so the learned
    per-point weights start at exactly 1 after the softplus.
    """
    with torch.no_grad():
        softplus_inverse_1 = 0.541324854612918
        weights.fill_(softplus_inverse_1)
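

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module; the layer
    # sizes and the pairing of initializers below are illustrative
    # assumptions): initialize a projection layer, a gating layer, and a
    # vector of point weights with the helpers above.
    proj = torch.nn.Linear(128, 256)
    lecun_normal_init_(proj.weight)
    bias_init_zero_(proj.bias)

    gate = torch.nn.Linear(128, 256)
    gating_init_(gate.weight)
    bias_init_one_(gate.bias)  # gates start open, near sigmoid(1) ~= 0.73

    point_weights = torch.zeros(12)  # e.g. one scalar weight per attention head
    ipa_point_weights_init_(point_weights)
    # After initialization, softplus maps every point weight to exactly 1.
    assert torch.allclose(
        torch.nn.functional.softplus(point_weights),
        torch.ones_like(point_weights),
    )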