| | from typing import Literal, Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from ..autoencoder.vae import VAE, get_my_vae |
| | from ..bigvgan import BigVGAN |
| | from ..bigvgan_v2.bigvgan import BigVGAN as BigVGANv2 |
| | from ...model.utils.distributions import DiagonalGaussianDistribution |
| |
|
| |
|
class AutoEncoderModule(nn.Module):
    """Frozen bundle of a VAE and a vocoder selected by ``mode``.

    The VAE weights are loaded from ``vae_ckpt_path``; the vocoder is either
    built from ``vocoder_ckpt_path`` (``'16k'``) or fetched with
    ``BigVGANv2.from_pretrained`` (``'44k'``). Every parameter registered on
    the module is frozen, and all public methods run under
    ``torch.inference_mode``.
    """

    def __init__(self,
                 *,
                 vae_ckpt_path,
                 vocoder_ckpt_path: Optional[str] = None,
                 mode: Literal['16k', '44k'],
                 need_vae_encoder: bool = True):
        """Build the VAE and the vocoder and freeze their parameters.

        Args:
            vae_ckpt_path: Checkpoint location handed to ``torch.load``; the
                loaded object is used as the VAE state dict.
            vocoder_ckpt_path: Checkpoint passed to ``BigVGAN``. Required when
                ``mode`` is ``'16k'``; not used when ``mode`` is ``'44k'``.
            mode: Selects the VAE variant (via ``get_my_vae``) and the vocoder.
            need_vae_encoder: When ``False`` the VAE's ``encoder`` attribute is
                deleted after loading, so ``encode`` should not be relied on
                afterwards (assumes ``VAE.encode`` uses that attribute --
                confirm against the VAE implementation).

        Raises:
            ValueError: If ``mode`` is ``'16k'`` and ``vocoder_ckpt_path`` is
                ``None``, or if ``mode`` is neither ``'16k'`` nor ``'44k'``.
        """
        super().__init__()
        self.vae: VAE = get_my_vae(mode).eval()
        # Load onto CPU with weights_only=True so only tensor data is unpickled.
        vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
        self.vae.load_state_dict(vae_state_dict)
        self.vae.remove_weight_norm()

        if mode == '16k':
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would let None reach BigVGAN and fail there
            # with a less helpful error.
            if vocoder_ckpt_path is None:
                raise ValueError("vocoder_ckpt_path is required when mode is '16k'")
            self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
        elif mode == '44k':
            self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
                                                     use_cuda_kernel=False)
            self.vocoder.remove_weight_norm()
        else:
            raise ValueError(f'Unknown mode: {mode}')

        # Freeze everything registered so far (VAE and vocoder).
        for param in self.parameters():
            param.requires_grad = False

        # Drop the encoder sub-module when the caller does not need it.
        if not need_vae_encoder:
            del self.vae.encoder

    @torch.inference_mode()
    def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
        """Return ``self.vae.encode(x)``, evaluated under inference mode."""
        return self.vae.encode(x)

    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Return ``self.vae.decode(z)``, evaluated under inference mode."""
        return self.vae.decode(z)

    @torch.inference_mode()
    def vocode(self, spec: torch.Tensor) -> torch.Tensor:
        """Run the vocoder on ``spec`` under inference mode.

        ``spec`` is presumably a spectrogram, judging by the name -- confirm
        the expected layout against the vocoder implementation.
        """
        return self.vocoder(spec)
| |
|