| import torch |
| from typing import Optional |
| from .config import TransformerConfig |
|
|
|
|
class MaskedGroupNorm(torch.nn.GroupNorm):
    """Group normalization whose statistics ignore padded positions.

    Without a mask this is exactly :class:`torch.nn.GroupNorm`. With a mask, the
    per-(sample, group) mean and biased variance are computed only over positions
    where the mask is non-zero, and the normalized (and optionally affine-
    transformed) result is multiplied by the mask, so padded positions come out
    as zero for a 0/1 mask.
    """

    def forward(
        self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Normalize ``x`` over channel groups, optionally ignoring padding.

        Args:
            x: Input of shape ``(B, C, L)``.
            padding_mask: Optional mask whose non-zero entries mark valid
                positions. Either ``(B, L)`` or any shape broadcastable to
                ``(B, C, L)`` such as ``(B, 1, L)`` or ``(B, C, L)``.

        Returns:
            Tensor of shape ``(B, C, L)``. When a mask is given, the output is
            multiplied by it (so masked positions are zero).
        """
        if padding_mask is None:
            return super().forward(x)

        B, C, L = x.shape
        G = self.num_groups
        if padding_mask.dim() == 2:
            # (B, L) -> (B, 1, L) so the mask broadcasts across channels.
            padding_mask = padding_mask.unsqueeze(1)

        # reshape (not view) so any memory layout of x / the expanded mask works.
        x_grouped = x.reshape(B, G, C // G, L)
        valid = padding_mask.expand_as(x).reshape(B, G, C // G, L).bool()
        invalid = ~valid
        dims = (2, 3)

        # Clamp the count: a group with no valid positions would otherwise give
        # 0 / 0 = NaN statistics, poisoning the output and its gradients. With the
        # clamp its mean/var are 0 and its outputs are zeroed by the mask below.
        count = valid.sum(dim=dims, keepdim=True).clamp_min(1)
        # masked_fill (rather than multiplying by the mask) keeps non-finite
        # values at padded positions out of the statistics.
        mean = x_grouped.masked_fill(invalid, 0).sum(dim=dims, keepdim=True) / count
        centered = x_grouped - mean
        var = (
            centered.masked_fill(invalid, 0).square().sum(dim=dims, keepdim=True)
            / count
        )

        x_norm = (centered / torch.sqrt(var + self.eps)).reshape(B, C, L)
        if self.affine:
            x_norm = x_norm * self.weight.view(1, -1, 1) + self.bias.view(1, -1, 1)
        return x_norm * padding_mask
|
|
|
|
class ConvBlock1d(torch.nn.Module):
    """Pre-activation convolution unit: masked GroupNorm -> SiLU -> Conv1d.

    The norm uses a single group, so its statistics span all channels and
    positions of a sample. The convolution maps ``config.hidden_size`` channels
    to ``config.hidden_size`` channels with a width-3 kernel and
    ``padding="same"``, so the sequence length is preserved.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        channels = config.hidden_size
        # Creation order (norm, activation, conv) is kept deliberately: it fixes
        # the state_dict key order and the order in which parameters are initialized.
        self.groupnorm = MaskedGroupNorm(num_groups=1, num_channels=channels)
        self.activation = torch.nn.SiLU()
        self.project = torch.nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding="same",
        )

    def forward(
        self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply norm, activation and convolution to ``x``.

        Args:
            x: Input with ``config.hidden_size`` channels; ``(B, C, L)`` is
                required when ``padding_mask`` is given.
            padding_mask: Optional mask forwarded unchanged to the
                ``MaskedGroupNorm`` layer.

        Returns:
            The convolution output. It has the same shape as ``x``.
        """
        normalized = self.groupnorm(x, padding_mask=padding_mask)
        activated = self.activation(normalized)
        return self.project(activated)
|
|
|
|
class ResnetBlock1d(torch.nn.Module):
    """Residual block of two stacked ``ConvBlock1d`` units.

    Computes ``block2(block1(x)) + x``; the output has the same shape as ``x``.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.block1 = ConvBlock1d(config)
        self.block2 = ConvBlock1d(config)

    def forward(
        self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run both conv units and add the skip connection.

        Args:
            x: Input of shape ``(B, C, L)``.
            padding_mask: Optional mask, expected to be ``(B, L)``. It is given a
                channel axis and expanded to ``x``'s shape before being passed to
                each conv unit.

        Returns:
            ``block2(block1(x)) + x``.
        """
        channel_mask = (
            None if padding_mask is None else padding_mask.unsqueeze(1).expand_as(x)
        )
        hidden = x
        for block in (self.block1, self.block2):
            hidden = block(hidden, padding_mask=channel_mask)
        return hidden + x
|
|