import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from vector_quantize_pytorch import ResidualVQ


class CausalConv1d(nn.Conv1d):
    """1-D convolution that pads only on the left of the time axis.

    All constructor arguments are forwarded unchanged to ``nn.Conv1d``.
    ``dilation * (kernel_size - 1)`` zeros are prepended to the last
    dimension before convolving, so no output sample depends on input
    samples that come after it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Amount of left context one kernel application spans.
        self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1)

    def forward(self, x):
        """Left-pad ``x`` along its last dimension and convolve it."""
        padded = F.pad(x, [self.causal_padding, 0])
        return self._conv_forward(padded, self.weight, self.bias)


class CausalConvTranspose1d(nn.ConvTranspose1d):
    """Transposed 1-D convolution whose trailing samples are trimmed.

    All constructor arguments are forwarded unchanged to
    ``nn.ConvTranspose1d``. After the transposed convolution, the last
    ``causal_padding`` samples of the time axis are dropped, where::

        causal_padding = dilation * (kernel_size - 1) + output_padding + 1 - stride

    With ``kernel_size == 2 * stride``, ``dilation == 1`` and no (output)
    padding, an input of length ``T`` therefore yields ``stride * T``
    samples.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.causal_padding = (
            self.dilation[0] * (self.kernel_size[0] - 1)
            + self.output_padding[0]
            + 1
            - self.stride[0]
        )

    def forward(self, x, output_size=None):
        """Apply the transposed convolution and trim the tail.

        Args:
            x: Input tensor; the last dimension is treated as time.
            output_size: Optional target size, passed through to
                ``nn.ConvTranspose1d.forward``.

        Returns:
            The transposed-convolution output without its last
            ``causal_padding`` time steps.

        Raises:
            ValueError: Propagated from ``nn.ConvTranspose1d`` if the
                padding mode is not ``'zeros'``.
        """
        # Delegate to the parent instead of calling the private
        # `_output_padding` helper, whose positional signature differs
        # between PyTorch releases.
        out = super().forward(x, output_size)
        # Slice by explicit end index: `out[..., :-0]` would be empty when
        # no trimming is required (e.g. kernel_size == stride). A negative
        # value (stride > receptive field) means there is nothing to trim.
        trim = max(self.causal_padding, 0)
        return out[..., : out.shape[-1] - trim]


class ResidualUnit(nn.Module):
    """Dilated causal conv (kernel 7) -> ELU -> 1x1 conv, with a skip connection.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions. The skip
            connection adds the input to the output, so this must be
            compatible with ``in_channels`` (every caller in this file
            passes equal values).
        dilation: Dilation of the kernel-7 causal convolution.
    """

    def __init__(self, in_channels, out_channels, dilation):
        super().__init__()
        self.dilation = dilation
        self.layers = nn.Sequential(
            CausalConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=7,
                dilation=dilation,
            ),
            nn.ELU(),
            # The pointwise conv consumes the output of the conv above, so
            # its input width is `out_channels`, not `in_channels`.
            nn.Conv1d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
            ),
        )

    def forward(self, x):
        """Return ``x + layers(x)``."""
        return x + self.layers(x)


class EncoderBlock(nn.Module):
    """Three residual units (dilations 1, 3, 9) followed by a strided conv.

    The residual units run at ``out_channels // 2`` channels; the final
    causal convolution (kernel ``2 * stride``) doubles the channel count
    to ``out_channels`` and downsamples time by ``stride``.

    Args:
        out_channels: Channels produced by the block; the block expects
            ``out_channels // 2`` input channels.
        stride: Downsampling factor of the final convolution.
    """

    def __init__(self, out_channels, stride):
        super().__init__()
        half = out_channels // 2
        self.layers = nn.Sequential(
            ResidualUnit(in_channels=half, out_channels=half, dilation=1),
            nn.ELU(),
            ResidualUnit(in_channels=half, out_channels=half, dilation=3),
            nn.ELU(),
            ResidualUnit(in_channels=half, out_channels=half, dilation=9),
            nn.ELU(),
            CausalConv1d(
                in_channels=half,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
            ),
        )

    def forward(self, x):
        """Run the block on ``x``."""
        return self.layers(x)


class DecoderBlock(nn.Module):
    """Strided transposed conv followed by three residual units.

    The causal transposed convolution (kernel ``2 * stride``) halves the
    channel count from ``2 * out_channels`` to ``out_channels`` and
    upsamples time by ``stride``; the residual units use dilations 1, 3, 9.

    Args:
        out_channels: Channels produced by the block; the block expects
            ``2 * out_channels`` input channels.
        stride: Upsampling factor of the transposed convolution.
    """

    def __init__(self, out_channels, stride):
        super().__init__()
        self.layers = nn.Sequential(
            CausalConvTranspose1d(
                in_channels=2 * out_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
            ),
            nn.ELU(),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels, dilation=1),
            nn.ELU(),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels, dilation=3),
            nn.ELU(),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels, dilation=9),
        )

    def forward(self, x):
        """Run the block on ``x``."""
        return self.layers(x)


class Encoder(nn.Module):
    """Map a 2-channel signal to ``D``-channel latents.

    Structure: causal conv (2 -> C, kernel 7), three ``EncoderBlock``s with
    strides 2, 4, 5 (channels 2C, 4C, 8C), and a final causal conv
    (8C -> D, kernel 3). The strides multiply to a 40x reduction of the
    time axis.

    Args:
        C: Base channel width.
        D: Latent (output) channel count.
    """

    def __init__(self, C, D):
        super().__init__()
        # NOTE: a fourth stage (EncoderBlock(out_channels=16*C, stride=8)
        # followed by a 16*C -> D projection) is currently disabled;
        # enabling it requires the matching stage in Decoder.
        self.layers = nn.Sequential(
            CausalConv1d(in_channels=2, out_channels=C, kernel_size=7),
            nn.ELU(),
            EncoderBlock(out_channels=2 * C, stride=2),
            nn.ELU(),
            EncoderBlock(out_channels=4 * C, stride=4),
            nn.ELU(),
            EncoderBlock(out_channels=8 * C, stride=5),
            nn.ELU(),
            CausalConv1d(in_channels=8 * C, out_channels=D, kernel_size=3),
        )

    def forward(self, x):
        """Encode ``x`` (2 channels on dim 1, time on the last dim)."""
        return self.layers(x)


class Decoder(nn.Module):
    """Map ``D``-channel latents back to a 2-channel signal.

    Structure: causal conv (D -> 8C, kernel 7), three ``DecoderBlock``s
    with strides 5, 4, 2 (channels 4C, 2C, C), and a final causal conv
    (C -> 2, kernel 7). The strides multiply to a 40x expansion of the
    time axis.

    Args:
        C: Base channel width.
        D: Latent (input) channel count.
    """

    def __init__(self, C, D):
        super().__init__()
        # NOTE: the stage mirroring the disabled encoder stage (a D -> 16*C
        # input conv followed by DecoderBlock(out_channels=8*C, stride=8))
        # is currently disabled as well.
        self.layers = nn.Sequential(
            CausalConv1d(in_channels=D, out_channels=8 * C, kernel_size=7),
            nn.ELU(),
            DecoderBlock(out_channels=4 * C, stride=5),
            nn.ELU(),
            DecoderBlock(out_channels=2 * C, stride=4),
            nn.ELU(),
            DecoderBlock(out_channels=C, stride=2),
            nn.ELU(),
            CausalConv1d(in_channels=C, out_channels=2, kernel_size=7),
        )

    def forward(self, x):
        """Decode latents ``x`` (``D`` channels on dim 1)."""
        return self.layers(x)


class SoundStream(nn.Module):
    """Encoder -> residual vector quantizer -> decoder.

    Args:
        C: Base channel width shared by encoder and decoder.
        D: Latent dimension, also the quantizer's ``dim``.
        n_q: Number of residual quantizers.
        codebook_size: Entries per quantizer codebook.
    """

    def __init__(self, C, D, n_q, codebook_size):
        super().__init__()
        self.encoder = Encoder(C=C, D=D)
        self.quantizer = ResidualVQ(
            num_quantizers=n_q,
            dim=D,
            codebook_size=codebook_size,
            kmeans_init=True,
            kmeans_iters=100,
            threshold_ema_dead_code=2,
        )
        self.decoder = Decoder(C=C, D=D)

    @staticmethod
    def pad_to_multiple(x, multiple):
        """Right-pad the time axis of ``x`` up to a multiple of ``multiple``.

        Reflection padding is used when possible. Reflection needs the pad
        to be shorter than the signal, so clips that are too short for it
        are zero-padded instead of raising.

        Args:
            x: Tensor of shape ``[B, C, T]``.
            multiple: Positive integer the padded length must divide by,
                e.g. 320.

        Returns:
            ``(padded_x, original_length)`` where ``original_length`` is
            ``T``; pass it to ``crop_to_length`` to undo the padding.

        Raises:
            ValueError: If ``multiple`` is not positive, or ``x`` is not
                3-dimensional.
        """
        if multiple <= 0:
            raise ValueError(f"multiple must be a positive integer, got {multiple}")
        _, _, T = x.shape
        pad_len = -T % multiple
        if pad_len == 0:
            return x, T
        mode = 'reflect' if pad_len < T else 'constant'
        padded_x = F.pad(x, (0, pad_len), mode=mode)
        return padded_x, T

    @staticmethod
    def crop_to_length(x, original_length):
        """Keep only the first ``original_length`` samples of the last axis."""
        return x[..., :original_length]

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Only the quantized latents are used; the quantizer's other two
        return values are discarded. No padding or cropping is applied
        here, so the output length is whatever the encoder/decoder strides
        produce for the given input length.

        Args:
            x: Input of shape ``[B, 2, T]``.

        Returns:
            Reconstruction of shape ``[B, 2, T_out]``.
        """
        latents = self.encoder(x)                    # [B, D, T']
        latents = latents.permute(0, 2, 1)           # -> [B, T', D]
        quantized, _, _ = self.quantizer(latents)
        quantized = quantized.permute(0, 2, 1)       # -> [B, D, T']
        return self.decoder(quantized)