import sys

import torch
from torch import nn
import numpy as np
from vq.module import WNConv1d, EncoderBlock, ResLSTM
from vq.alias_free_torch import *
from vq import activations
from vq.bs_roformer5 import TransformerBlock

from torchtune.modules import RotaryPositionalEmbeddings
import vq.blocks as blocks
from torch.nn import utils

def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:  # some conv layers may be built without a bias
            nn.init.constant_(m.bias, 0)


class CodecEncoder(nn.Module):
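    """Convolutional waveform encoder.

    A weight-normalized Conv1d stem is followed by one EncoderBlock per entry
    in ``up_ratios`` (each doubles the channel count while downsampling in
    time), an optional ResLSTM, and a SnakeBeta-activated output convolution
    projecting to ``out_channels``. The overall temporal downsampling factor
    is ``hop_length = prod(up_ratios)``.
    """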
    def __init__(self,
                 ngf=48,
                 use_rnn=True,
                 rnn_bidirectional=False,
                 rnn_num_layers=2,
                 up_ratios=(2, 2, 4, 4, 5),
                 dilations=(1, 3, 9),
                 out_channels=1024):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.ngf = ngf
        self.up_ratios = up_ratios

        d_model = ngf
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        for i, stride in enumerate(up_ratios):
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride, dilations=dilations)]

        if use_rnn:
            self.block += [
                ResLSTM(d_model,
                        num_layers=rnn_num_layers,
                        bidirectional=rnn_bidirectional)
            ]

        self.block += [
            Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)),
            WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
        ]

        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

        self.reset_parameters()

    def forward(self, x):
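        # Expects x of shape (B, 1, T); returns (B, out_channels, T // hop_length).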
        out = self.block(x)
        return out

    def inference(self, x):
        return self.block(x)

    def remove_weight_norm(self):
        """Remove weight normalization from all layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to all Conv1d layers."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)


class Transpose(nn.Module):
    def __init__(self, dim1, dim2):
        super().__init__()
        self.dim1 = dim1
        self.dim2 = dim2

    def forward(self, x):
        return x.transpose(self.dim1, self.dim2)


class CodecEncoder_Transformer(nn.Module):
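    """Convolutional encoder that returns time-major features.

    Uses the same WNConv1d stem and EncoderBlock stack as CodecEncoder,
    followed by a SnakeBeta-activated projection to ``hidden_dim``; the output
    is permuted to (B, T // hop_length, hidden_dim). The ``depth``, ``heads``
    and ``pos_meb_dim`` arguments are accepted for interface compatibility but
    are not used by this implementation.
    """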
    def __init__(self,
                 ngf=48,
                 up_ratios=(2, 2, 4, 4, 5),
                 dilations=(1, 3, 9),
                 hidden_dim=1024,
                 depth=12,
                 heads=12,
                 pos_meb_dim=64,
                 ):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.ngf = ngf
        self.up_ratios = up_ratios

        d_model = ngf
        self.conv_blocks = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        for i, stride in enumerate(up_ratios):
            d_model *= 2
            self.conv_blocks += [EncoderBlock(d_model, stride=stride, dilations=dilations)]

        self.conv_blocks = nn.Sequential(*self.conv_blocks)

        self.conv_final_block = [
            Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)),
            WNConv1d(d_model, hidden_dim, kernel_size=3, padding=1),
        ]
        self.conv_final_block = nn.Sequential(*self.conv_final_block)

        self.reset_parameters()

    def forward(self, x):
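        # Expects x of shape (B, 1, T); returns (B, T // hop_length, hidden_dim).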
        x = self.conv_blocks(x)
        x = self.conv_final_block(x)
        x = x.permute(0, 2, 1)
        return x

    def inference(self, x):
        # Note: this class has no ``self.block``; reuse the forward pass instead.
        return self.forward(x)

    def remove_weight_norm(self):
        """Remove weight normalization from all layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to all Conv1d layers."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)


class Codec_oobleck_Transformer(nn.Module):
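    """Dilated residual convolutional encoder followed by a Transformer stack.

    The convolutional front end is a ``vq.blocks.DilatedResidualEncoder`` built
    from the dilated/downsampling units defined below. Its output is permuted
    to time-major form, refined by ``depth`` TransformerBlocks with rotary
    positional embeddings, and finally layer-normalized.
    """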
    def __init__(self,
                 ngf=32,
                 up_ratios=(2, 2, 4, 4, 5),
                 dilations=(1, 3, 9),
                 hidden_dim=1024,
                 depth=12,
                 heads=16,
                 pos_meb_dim=64,
                 ):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.ngf = ngf
        self.up_ratios = up_ratios
        self.hidden_dim = hidden_dim

        self.conv_blocks = blocks.DilatedResidualEncoder(
            capacity=ngf,
            dilated_unit=self.dilated_unit,
            downsampling_unit=self.downsampling_unit,
            ratios=up_ratios,
            dilations=dilations,
            pre_network_conv=self.pre_conv,
            post_network_conv=self.post_conv,
        )

        time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)

        transformer_blocks = [
            TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
            for _ in range(depth)
        ]

        self.transformers = nn.Sequential(*transformer_blocks)

        self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)

        self.reset_parameters()

    def forward(self, x):
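        # Expects x of shape (B, 1, T); returns (B, T // hop_length, hidden_dim).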
        x = self.conv_blocks(x)
        x = x.permute(0, 2, 1)
        x = self.transformers(x)
        x = self.final_layer_norm(x)
        return x

    def inference(self, x):
        # Note: this class has no ``self.block``; reuse the forward pass instead.
        return self.forward(x)

    def remove_weight_norm(self):
        """Remove weight normalization from all layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to all Conv1d layers."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)

    def dilated_unit(self, hidden_dim, dilation):
        return blocks.DilatedConvolutionalUnit(hidden_dim,
                                               dilation,
                                               kernel_size=3,
                                               activation=nn.ReLU,
                                               normalization=utils.weight_norm)

    def downsampling_unit(self, input_dim: int, output_dim: int, stride: int):
        return blocks.DownsamplingUnit(input_dim,
                                       output_dim,
                                       stride,
                                       nn.ReLU,
                                       normalization=utils.weight_norm)

    def pre_conv(self, out_channels):
        return nn.Conv1d(1, out_channels, 1)

    def post_conv(self, in_channels):
        return nn.Conv1d(in_channels, self.hidden_dim, 1)


class CodecEncoder_only_Transformer(nn.Module):
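    """Transformer-only encoder: a stack of TransformerBlocks with rotary
    positional embeddings applied to already-extracted features, followed by a
    final LayerNorm."""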
    def __init__(self, hidden_dim=1024, depth=12, heads=16, pos_meb_dim=64):
        super().__init__()

        time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)

        transformer_blocks = [
            TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
            for _ in range(depth)
        ]

        self.transformers = nn.Sequential(*transformer_blocks)

        self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
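        # Expects and returns features of shape (B, T, hidden_dim).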
        x = self.transformers(x)
        x = self.final_layer_norm(x)
        return x


def get_model_size(model):
    """Return (total parameter count, parameter memory in MB)."""
    total_params = sum(p.numel() for p in model.parameters())
    # Size in bytes uses each parameter's element size (e.g. 4 bytes for
    # float32) rather than assuming one byte per parameter.
    model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    model_size_mb = model_size_bytes / (1024 ** 2)
    return total_params, model_size_mb


if __name__ == '__main__':
    model = Codec_oobleck_Transformer()
    x = torch.randn(1, 1, 16000)
    output = model(x)
    print("Output shape:", output.shape)
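
    # Illustrative extra checks (not in the original script): exercise the
    # parameter-count helper and the plain CodecEncoder in the same way.
    # With the default up_ratios, hop_length = 2*2*4*4*5 = 320, so a
    # 16000-sample input is expected to yield 16000 // 320 = 50 frames.
    total_params, size_mb = get_model_size(model)
    print(f"Codec_oobleck_Transformer: {total_params} params, ~{size_mb:.1f} MB")

    encoder = CodecEncoder()
    enc_out = encoder(x)
    print("CodecEncoder output shape:", enc_out.shape)  # expected (1, 1024, 50)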