from functools import reduce
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from torch.backends.cuda import sdp_kernel
from packaging import version

from dac.nn.layers import Snake1d

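# Generic residual wrapper: output = main(input) + skip(input), with skip defaulting to identity.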
class ResidualBlock(nn.Module):
    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        self.skip = skip if skip else nn.Identity()

    def forward(self, input):
        return self.main(input) + self.skip(input)

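# Residual conv block: two Conv1d + GroupNorm + activation stages, with a 1x1 conv skip when
# the channel count changes; the final norm/activation are dropped when is_last=True.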
class ResConvBlock(ResidualBlock):
    def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
        skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
        super().__init__([
            nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size // 2, bias=conv_bias),
            nn.GroupNorm(1, c_mid),
            Snake1d(c_mid) if use_snake else nn.GELU(),
            nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size // 2, bias=conv_bias),
            nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
            (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
        ], skip)

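# Multi-head self-attention over 1D feature maps [N, C, S], using 1x1 convs for the QKV and
# output projections and a residual connection around the whole block.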
class SelfAttention1d(nn.Module):
    def __init__(self, c_in, n_head=1, dropout_rate=0.):
        super().__init__()
        assert c_in % n_head == 0
        self.norm = nn.GroupNorm(1, c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv1d(c_in, c_in, 1)
        self.dropout = nn.Dropout(dropout_rate, inplace=True)

    def forward(self, input):
        n, c, s = input.shape
        qkv = self.qkv_proj(self.norm(input))
        qkv = qkv.view(
            [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
        q, k, v = qkv.chunk(3, dim=1)
        # Apply the 1/sqrt(d) attention scaling as d**-0.25 on both q and k.
        scale = k.shape[3] ** -0.25

        att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
        y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])

        return input + self.dropout(self.out_proj(y))

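# U-Net style skip: runs the main branch and concatenates its output with the input along channels.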
class SkipBlock(nn.Module):
    def __init__(self, *main):
        super().__init__()
        self.main = nn.Sequential(*main)

    def forward(self, input):
        return torch.cat([self.main(input), input], dim=1)

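# Random Fourier feature embedding: projects the input (e.g. a timestep) with a
# Gaussian-initialized weight (std sets the frequency spread) and returns concatenated
# cos/sin features of size out_features.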
class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        self.weight = nn.Parameter(torch.randn(
            [out_features // 2, in_features]) * std)

    def forward(self, input):
        f = 2 * math.pi * input @ self.weight.T
        return torch.cat([f.cos(), f.sin()], dim=-1)

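# Broadcast a per-item conditioning vector [N, C] across the sequence dimension to [N, C, S].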
def expand_to_planes(input, shape):
    return input[..., None].repeat([1, 1, shape[2]])

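# Fixed low-pass interpolation kernels used by Downsample1d / Upsample1d below.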
_kernels = {
    'linear':
        [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    'cubic':
        [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
         0.43359375, 0.11328125, -0.03515625, -0.01171875],
    'lanczos3':
        [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
         -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
         0.44638532400131226, 0.13550527393817902, -0.066637322306633,
         -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}

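# Antialiased 2x downsampling: pads the input, then applies the chosen kernel as a per-channel
# (diagonal) conv1d with stride 2. Expects [N, C, S], or [N, S, C] when channels_last=True.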
class Downsample1d(nn.Module):
    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel])
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer('kernel', kernel_1d)
        self.channels_last = channels_last

    def forward(self, x):
        if self.channels_last:
            x = x.permute(0, 2, 1)
        x = F.pad(x, (self.pad,) * 2, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        x = F.conv1d(x, weight, stride=2)
        if self.channels_last:
            x = x.permute(0, 2, 1)
        return x

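# Matching 2x upsampling: the same kernels scaled by 2 (to preserve signal gain), applied
# per channel with conv_transpose1d at stride 2.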
class Upsample1d(nn.Module):
    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel]) * 2
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer('kernel', kernel_1d)
        self.channels_last = channels_last

    def forward(self, x):
        if self.channels_last:
            x = x.permute(0, 2, 1)
        x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
        if self.channels_last:
            x = x.permute(0, 2, 1)
        return x

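# Learned downsampling by an arbitrary integer factor via a strided Conv1d.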
def Downsample1d_2(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"

    return nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=factor * kernel_multiplier + 1,
        stride=factor,
        padding=factor * (kernel_multiplier // 2),
    )

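# Learned upsampling by an integer factor: a plain conv for factor 1, nearest-neighbor
# upsample + conv when use_nearest=True, otherwise a ConvTranspose1d.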
def Upsample1d_2(
    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:

    if factor == 1:
        return nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
        )

    if use_nearest:
        return nn.Sequential(
            nn.Upsample(scale_factor=factor, mode="nearest"),
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
            ),
        )
    else:
        return nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=factor * 2,
            stride=factor,
            padding=factor // 2 + factor % 2,
            output_padding=factor % 2,
        )

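# Zero-initialize a layer's weight (and bias, if present); used below so conditioning
# projections start out as a no-op.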
def zero_init(layer):
    nn.init.zeros_(layer.weight)
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)
    return layer

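# RMS normalization over the last dimension: x * scale / sqrt(mean(x**2) + eps),
# computed in at least float32 before casting back to the input dtype.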
def rms_norm(x, scale, eps):
    dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
    mean_sq = torch.mean(x.to(dtype) ** 2, dim=-1, keepdim=True)
    scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
    return x * scale.to(x.dtype)

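# Adaptive RMSNorm: the per-feature scale is predicted from a conditioning vector by a
# zero-initialized linear layer (plus 1), so the module starts out as plain RMSNorm.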
class AdaRMSNorm(nn.Module):
    def __init__(self, features, cond_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.linear = zero_init(nn.Linear(cond_features, features, bias=False))

    def extra_repr(self):
        return f"eps={self.eps},"

    def forward(self, x, cond):
        return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)

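# Normalize a tensor to unit RMS over all non-batch dimensions (eps keeps the scale bounded
# for near-zero inputs).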
def normalize(x, eps=1e-4):
    dim = list(range(1, x.ndim))
    n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
    alpha = np.sqrt(n.numel() / x.numel())
    return x / torch.add(eps, n, alpha=alpha)

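# Conv1d with forced weight normalization: the raw weight is re-projected to unit RMS
# in-place during training, and the effective weight is normalized and scaled by
# 1/sqrt(fan_in) at every forward pass.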
class ForcedWNConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1):
        super().__init__()
        self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))

    def forward(self, x):
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(self.weight))

        fan_in = self.weight[0].numel()

        w = normalize(self.weight) / math.sqrt(fan_in)

        return F.conv1d(x, w, padding='same')

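# Optional kernel compilation: wraps torch.compile behind a module-level flag and falls back
# to the plain function when use_compile is False or compilation raises a RuntimeError.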
use_compile = False


def compile(function, *args, **kwargs):
    if not use_compile:
        return function
    try:
        return torch.compile(function, *args, **kwargs)
    except RuntimeError:
        return function

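# Fused GEGLU: a single matmul produces both halves, which are split into a value and a gate,
# with the gate passed through GELU.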
@compile
def linear_geglu(x, weight, bias=None):
    x = x @ weight.mT
    if bias is not None:
        x = x + bias
    x, gate = x.chunk(2, dim=-1)
    return x * F.gelu(gate)

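# Note: this redefinition shadows the eager rms_norm above, adding the optional compile wrapper.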
@compile
def rms_norm(x, scale, eps):
    dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
    mean_sq = torch.mean(x.to(dtype) ** 2, dim=-1, keepdim=True)
    scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
    return x * scale.to(x.dtype)

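# Linear layer with a built-in GEGLU activation: projects to 2 * out_features and gates the result.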
class LinearGEGLU(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features * 2, bias=bias)
        self.out_features = out_features

    def forward(self, x):
        return linear_geglu(x, self.weight, self.bias)

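# RMSNorm module wrapping rms_norm, with a learnable per-feature scale unless fix_scale=True.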
class RMSNorm(nn.Module):
    def __init__(self, shape, fix_scale=False, eps=1e-6):
        super().__init__()
        self.eps = eps

        if fix_scale:
            self.register_buffer("scale", torch.ones(shape))
        else:
            self.scale = nn.Parameter(torch.ones(shape))

    def extra_repr(self):
        return f"shape={tuple(self.scale.shape)}, eps={self.eps}"

    def forward(self, x):
        return rms_norm(x, self.scale, self.eps)

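# Snake-beta activation: x + (1 / beta) * sin(alpha * x) ** 2, with a small epsilon added to
# beta to avoid division by zero.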
def snake_beta(x, alpha, beta):
    return x + (1.0 / (beta + 0.000000001)) * torch.sin(x * alpha) ** 2

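# SnakeBeta module: per-channel trainable alpha (frequency) and beta (magnitude) parameters,
# optionally stored in log scale; expects inputs shaped [B, C, T].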
class SnakeBeta(nn.Module):

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # Initialize alpha and beta: zeros in log scale (exp(0) = 1), or alpha in linear scale.
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else:
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        # Reshape [C] parameters to [1, C, 1] so they broadcast against [B, C, T] inputs.
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)

        return x