import math
from typing import Optional, Union

import torch
import torch.nn as nn

from .config import InitFnType, ModelConfig
from .util import StrEnum

__all__ = ["init_weights", "ModuleType"]

class ModuleType(StrEnum):
    """
    The role a module plays in the model; determines the standard deviation used by the
    "full_megatron" initialization in :func:`init_weights`.
    """

    in_module = "in"
    out_module = "out"
    emb = "emb"
    final_out = "final_out"

def init_weights(
    config: ModelConfig,
    module: Union[nn.Linear, nn.Embedding],
    d: Optional[int] = None,
    layer_id: Optional[int] = None,
    std_factor: float = 1.0,
    type_of_module: Optional[ModuleType] = None,
) -> None:
    """
    Initialize weights of a linear or embedding module.

    :param config: The model config.
    :param module: The linear or embedding submodule to initialize.
    :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
        for fused layers.
    :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
        ``1 / sqrt(2 * (layer_id + 1))``.
    :param std_factor: An extra multiplier applied to the standard deviation for the "normal", "mitchell",
        and "fan_in" methods.
    :param type_of_module: The role of the module being initialized. Required for the "full_megatron" method.
    """
    d = d if d is not None else config.d_model
    if config.init_fn == InitFnType.normal:
        # Fixed-std normal init, optionally truncated at ``init_cutoff_factor`` standard deviations.
        std = config.init_std * std_factor
        if config.init_cutoff_factor is not None:
            cutoff_value = config.init_cutoff_factor * std
            nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
        else:
            nn.init.normal_(module.weight, mean=0.0, std=std)
    elif config.init_fn == InitFnType.mitchell:
        # Width-aware init: std scales as 1 / sqrt(d), optionally reduced by 1 / sqrt(2 * (layer_id + 1)).
        std = std_factor / math.sqrt(d)
        if layer_id is not None:
            std = std / math.sqrt(2 * (layer_id + 1))
        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
    elif config.init_fn == InitFnType.kaiming_normal:
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
    elif config.init_fn == InitFnType.fan_in:
        # Plain fan-in scaling without truncation.
        std = std_factor / math.sqrt(d)
        nn.init.normal_(module.weight, mean=0.0, std=std)
    elif config.init_fn == InitFnType.full_megatron:
        if type_of_module is None:
            raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")

        cutoff_factor = config.init_cutoff_factor
        if cutoff_factor is None:
            cutoff_factor = 3

        if type_of_module == ModuleType.in_module:
            # Input projections.
            std = config.init_std
        elif type_of_module == ModuleType.out_module:
            # Output projections, scaled down with depth.
            std = config.init_std / math.sqrt(2.0 * config.n_layers)
        elif type_of_module == ModuleType.emb:
            # Token / positional embeddings.
            std = config.init_std
        elif type_of_module == ModuleType.final_out:
            # Final output projection.
            std = config.d_model**-0.5
        else:
            raise RuntimeError(f"Unknown module type '{type_of_module}'")
        nn.init.trunc_normal_(
            module.weight,
            mean=0.0,
            std=std,
            a=-cutoff_factor * std,
            b=cutoff_factor * std,
        )
    else:
        raise NotImplementedError(config.init_fn)

    if isinstance(module, nn.Linear):
        if module.bias is not None:
            nn.init.zeros_(module.bias)

        # For the "normal" init, scale linear layers flagged as residual outputs down by sqrt(2 * n_layers).
        if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
            with torch.no_grad():
                module.weight.div_(math.sqrt(2 * config.n_layers))
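

def _example_init_usage() -> None:
    """
    Minimal illustrative sketch (not part of the public API) of how ``init_weights`` can be
    applied to individual submodules. ``_StubConfig`` below is a hypothetical stand-in that
    exposes only the ``ModelConfig`` attributes ``init_weights`` reads (``d_model``,
    ``n_layers``, ``init_fn``, ``init_std``, ``init_cutoff_factor``); the real ``ModelConfig``
    constructor and field set may differ.
    """
    from dataclasses import dataclass

    @dataclass
    class _StubConfig:
        d_model: int = 64
        n_layers: int = 2
        init_fn: InitFnType = InitFnType.mitchell
        init_std: float = 0.02
        init_cutoff_factor: Optional[float] = None

    config = _StubConfig()

    # "mitchell": std = std_factor / sqrt(d), further scaled by 1 / sqrt(2 * (layer_id + 1)).
    attn_out = nn.Linear(config.d_model, config.d_model, bias=False)
    init_weights(config, attn_out, d=config.d_model, layer_id=0)

    # "full_megatron": every module must declare its role via ``type_of_module``;
    # embeddings use ``init_std`` while the final output projection uses ``d_model ** -0.5``.
    config.init_fn = InitFnType.full_megatron
    wte = nn.Embedding(256, config.d_model)
    lm_head = nn.Linear(config.d_model, 256, bias=False)
    init_weights(config, wte, type_of_module=ModuleType.emb)
    init_weights(config, lm_head, type_of_module=ModuleType.final_out)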