| | import torch |
| | import torch.nn as nn |
| |
|
| | from modules.ChatTTS.ChatTTS.model.dvae import ConvNeXtBlock, DVAEDecoder |
| |
|
| | from .wavenet import WaveNet |
| |
|
| |
|
def get_encoder_config(decoder: DVAEDecoder) -> dict[str, int | bool]:
    """Build ``DVAEEncoder`` constructor keyword arguments from a decoder.

    The encoder is configured as the mirror image of ``decoder``: the
    decoder's output channel count becomes the encoder's ``idim`` and the
    decoder's input channel count becomes the encoder's ``odim``.  All other
    sizes are read straight off the decoder's layers.

    Args:
        decoder: Decoder whose ``conv_in``, ``decoder_block``, ``conv_out``
            and ``up`` attributes are inspected.

    Returns:
        A dict with the keys ``idim``, ``odim``, ``n_layer``, ``bn_dim``,
        ``hidden``, ``kernel``, ``dilation`` and ``down``.
    """
    # conv_in is indexed at 0 and 2; index 1 is skipped here
    # (presumably an activation layer -- confirm against DVAEDecoder).
    first_conv = decoder.conv_in[0]
    second_conv = decoder.conv_in[2]
    blocks = decoder.decoder_block
    # Only the first block is sampled for kernel size and dilation.
    depthwise = blocks[0].dwconv

    config = dict(
        idim=decoder.conv_out.out_channels,
        odim=first_conv.in_channels,
        n_layer=len(blocks),
        bn_dim=first_conv.out_channels,
        hidden=second_conv.out_channels,
        kernel=depthwise.kernel_size[0],
        dilation=depthwise.dilation[0],
        down=decoder.up,
    )
    return config
| |
|
| |
|
class DVAEEncoder(nn.Module):
    """Encoder that maps mel spectrograms to latent features.

    The pipeline is: ``WaveNet`` front end -> 1x1 ``ConvTranspose1d`` ->
    a stack of ``ConvNeXtBlock`` layers -> two-layer ``Conv1d`` head, followed
    by a reshape that halves the time axis and doubles the feature axis.
    """

    def __init__(
        self,
        idim: int,
        odim: int,
        n_layer: int = 12,
        bn_dim: int = 64,
        hidden: int = 256,
        kernel: int = 7,
        dilation: int = 2,
        down: bool = False,
    ) -> None:
        """Create the encoder layers.

        Args:
            idim: Residual channel count of the WaveNet front end and input
                channel count of ``conv_in_transpose``.
            odim: Output channel count of the convolutional head.
            n_layer: Number of ``ConvNeXtBlock`` layers.
            bn_dim: Channel count between the two head convolutions.
            hidden: Channel count used by the ``ConvNeXtBlock`` stack.
            kernel: Kernel size forwarded to every ``ConvNeXtBlock``.
            dilation: Dilation forwarded to every ``ConvNeXtBlock``.
            down: Accepted for signature compatibility; not read anywhere
                in this class.
        """
        super().__init__()

        # NOTE(review): input_channels is fixed at 100 (presumably the number
        # of mel bins -- confirm against the feature extractor).
        self.wavenet = WaveNet(
            input_channels=100,
            residual_channels=idim,
            residual_layers=20,
            dilation_cycle=4,
        )

        # Pointwise (kernel size 1) projection from idim to hidden channels.
        self.conv_in_transpose = nn.ConvTranspose1d(
            idim, hidden, kernel_size=1, bias=False
        )

        expanded = hidden * 4
        stack = [
            ConvNeXtBlock(hidden, expanded, kernel, dilation)
            for _ in range(n_layer)
        ]
        self.encoder_block = nn.ModuleList(stack)

        self.conv_out_transpose = nn.Sequential(
            nn.Conv1d(hidden, bn_dim, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv1d(bn_dim, odim, kernel_size=3, stride=1, padding=1),
        )

    def forward(
        self,
        audio_mel_specs: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        conditioning=None,
    ) -> torch.Tensor:
        """Encode mel spectrograms.

        Args:
            audio_mel_specs: Mel input; axes 1 and 2 are swapped before the
                WaveNet (assumes ``(batch, mel_len, n_mels)`` -- TODO confirm).
            audio_attention_mask: Mask whose entries are each duplicated to
                cover two consecutive mel frames (assumes ``(batch, audio_len)``
                with ``mel_len == 2 * audio_len`` -- TODO confirm).
            conditioning: Optional value passed unchanged to every block.

        Returns:
            For a 3-D head output ``(batch, C, T)``, a tensor of shape
            ``(batch, T // 2, 2 * C)``.
        """
        # Duplicate every mask entry: [a, b] -> [a, a, b, b].
        frame_mask = audio_attention_mask.unsqueeze(-1)
        frame_mask = frame_mask.repeat(1, 1, 2)
        frame_mask = frame_mask.flatten(1)

        channels_first = audio_mel_specs.transpose(1, 2)
        feats: torch.Tensor = self.wavenet(channels_first)
        # Broadcast the mask over the channel axis to zero masked frames.
        feats = feats * frame_mask.unsqueeze(1)

        feats = self.conv_in_transpose(feats)
        for block in self.encoder_block:
            feats = block(feats, conditioning)
        feats = self.conv_out_transpose(feats)

        n_batch = feats.size(0)
        n_chan = feats.size(1)
        n_frame = feats.size(2)
        # NOTE(review): splitting the time axis as (2, T // 2) makes output
        # step j combine frames j and j + T // 2, while the mask above pairs
        # adjacent frames (2k, 2k + 1). Confirm this layout is intended.
        split = feats.view(n_batch, n_chan, 2, n_frame // 2)
        time_major = split.permute(0, 3, 1, 2)
        return time_major.flatten(2)
| |
|