# Uploaded by haiphamcse using huggingface_hub ("Upload folder using huggingface_hub").
# Commit: 9855f47 (verified)
import torch
from typing import Optional
from .config import TransformerConfig
class MaskedGroupNorm(torch.nn.GroupNorm):
    """GroupNorm whose statistics ignore masked-out positions.

    With ``padding_mask=None`` this is exactly :class:`torch.nn.GroupNorm`.
    Otherwise the per-(sample, group) mean and biased variance are computed
    only over positions where the mask is nonzero / True. The result is then
    multiplied by the mask, so masked positions come out as zero.
    """

    def forward(
        self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Normalize ``x`` of shape (B, C, L) using only the kept positions.

        Args:
            x: Input of shape (B, C, L); C must be divisible by ``num_groups``.
            padding_mask: Optional mask in which nonzero / True marks positions
                that are KEPT. Shape (B, C, L), or (B, 1, L) / (B, L), which
                are broadcast across channels.

        Returns:
            Tensor of shape (B, C, L). Masked positions are zero (the output is
            multiplied by the mask). A (sample, group) with no kept position
            yields zeros instead of NaN.
        """
        if padding_mask is None:
            return super().forward(x)
        B, C, L = x.shape
        G = self.num_groups
        # Accept per-position masks and broadcast them over channels.
        if padding_mask.shape == (B, L):
            padding_mask = padding_mask.unsqueeze(1)
        if padding_mask.shape == (B, 1, L):
            padding_mask = padding_mask.expand(B, C, L)
        # reshape (not view) so non-contiguous inputs and expanded masks work.
        x_grouped = x.reshape(B, G, C // G, L)
        keep = padding_mask.reshape(B, G, C // G, L).bool()
        # Accumulate statistics in at least float32: for half precision the
        # element count (up to C * L) and the sums can overflow.
        stats_dtype = torch.promote_types(x.dtype, torch.float32)
        # Zero masked entries BEFORE any arithmetic so NaN/inf stored in the
        # padding cannot leak into the statistics or the output.
        x_kept = x_grouped.to(stats_dtype).masked_fill(~keep, 0.0)
        # clamp(min=1) prevents 0/0 = NaN for a group with no kept positions.
        count = keep.sum(dim=(2, 3), keepdim=True).clamp(min=1).to(stats_dtype)
        mean = x_kept.sum(dim=(2, 3), keepdim=True) / count
        centered = (x_kept - mean).masked_fill(~keep, 0.0)
        var = centered.pow(2).sum(dim=(2, 3), keepdim=True) / count  # biased
        x_norm = (centered * torch.rsqrt(var + self.eps)).to(x.dtype)
        x_norm = x_norm.reshape(B, C, L)
        if self.affine:
            x_norm = x_norm * self.weight.view(1, -1, 1) + self.bias.view(1, -1, 1)
        return x_norm * padding_mask
class ConvBlock1d(torch.nn.Module):
    """Pre-activation conv unit: masked GroupNorm -> SiLU -> Conv1d(k=3).

    Input and output both have ``config.hidden_size`` channels, and
    ``padding="same"`` keeps the sequence length unchanged.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        width = config.hidden_size
        # One group: statistics are pooled over all channels and the
        # (unmasked) positions of each sample.
        self.groupnorm = MaskedGroupNorm(num_groups=1, num_channels=width)
        self.activation = torch.nn.SiLU()
        self.project = torch.nn.Conv1d(width, width, kernel_size=3, padding="same")

    def forward(
        self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Normalize, activate and convolve ``x``.

        ``padding_mask`` is handed unchanged to the group norm, which zeroes
        masked positions before the activation. The convolution output is NOT
        re-masked, so masked positions may be nonzero afterwards (e.g. the
        conv bias).
        """
        normed = self.groupnorm(x, padding_mask=padding_mask)
        return self.project(self.activation(normed))
class ResnetBlock1d(torch.nn.Module):
    """Residual block: ``x + block2(block1(x))`` using two ConvBlock1d units."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.block1 = ConvBlock1d(config)
        self.block2 = ConvBlock1d(config)

    def forward(
        self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply both conv units and add the skip connection.

        ``padding_mask`` gets a new axis at dim 1 and is broadcast to
        ``x.shape`` before being passed to each unit. NOTE(review): this
        presumably expects a (batch, length) mask; confirm against callers.
        The residual sum itself is not masked.
        """
        if padding_mask is None:
            mask = None
        else:
            mask = padding_mask.unsqueeze(1).expand_as(x)
        out = x
        for unit in (self.block1, self.block2):
            out = unit(out, padding_mask=mask)
        return out + x