| |
| |
| |
| |
| |
|
|
| """Normalization modules.""" |
|
|
| import typing as tp |
|
|
| import einops |
| import torch |
| from torch import nn |
|
|
|
|
class ConvLayerNorm(nn.LayerNorm):
    """LayerNorm variant that leaves the last axis of the input out of the statistics.

    The last axis of the input is moved to position 1 (right after the
    leading axis), the regular ``nn.LayerNorm`` is applied to the axes that
    are now trailing, and the moved axis is then put back in last position.
    The output therefore has the same axis order as the input.

    Args:
        normalized_shape: Forwarded unchanged to ``nn.LayerNorm``. It has to
            match the axes that sit between the first and the last axis of
            the input, since those are the trailing axes at normalization time.
        **kwargs: Extra keyword arguments forwarded to ``nn.LayerNorm``.
    """

    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and return a tensor with the same axis order.

        Args:
            x: Input with at least two axes.
                NOTE(review): presumably laid out as (batch, channels, time);
                confirm against callers.

        Returns:
            The normalized tensor, with the last axis restored to its
            original position.
        """
        # Same permutation as the pattern 'b ... t -> b t ...': the last axis
        # goes to index 1 so the axes to normalize become the trailing ones.
        x = x.movedim(-1, 1)
        x = super().forward(x)
        # Inverse permutation ('b t ... -> b ... t'). The result must be
        # returned explicitly; a bare `return` would hand `None` to callers.
        return x.movedim(1, -1)
|
|