| | import torch |
| | import math |
| | import numpy as np |
| |
|
| | from torch import nn |
| | from torch.nn import functional as F |
| | from torchaudio import transforms as T |
| | from alias_free_torch import Activation1d |
| | from .nn.layers import WNConv1d, WNConvTranspose1d |
| | from typing import Literal, Dict, Any |
| |
|
| | |
| | from .utils import prepare_audio |
| | from .blocks import SnakeBeta |
| | from .bottleneck import Bottleneck, DiscreteBottleneck |
| | from .factory import create_pretransform_from_config, create_bottleneck_from_config |
| | from .pretransforms import Pretransform |
| |
|
def checkpoint(function, *args, **kwargs):
    """Call ``torch.utils.checkpoint.checkpoint`` with a non-reentrant default.

    ``use_reentrant`` is set to ``False`` unless the caller passes it explicitly;
    every other positional/keyword argument is forwarded untouched.
    """
    if "use_reentrant" not in kwargs:
        kwargs["use_reentrant"] = False
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
| |
|
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Build the activation module selected by name.

    Args:
        activation: ``"elu"`` -> ``nn.ELU``, ``"snake"`` -> ``SnakeBeta(channels)``,
            ``"none"`` -> ``nn.Identity``.
        antialias: when true, the activation is wrapped in ``Activation1d``.
        channels: only used by the ``"snake"`` activation.

    Raises:
        ValueError: if ``activation`` is not one of the names above.
    """
    if activation == "elu":
        module = nn.ELU()
    elif activation == "snake":
        module = SnakeBeta(channels)
    elif activation == "none":
        module = nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    return Activation1d(module) if antialias else module
| |
|
class ResidualUnit(nn.Module):
    """Residual branch: activation -> dilated k=7 conv -> activation -> k=1 conv.

    The branch output is added to the unmodified input in ``forward``.
    """

    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        kernel_size = 7
        # Symmetric padding that compensates for the dilated kernel's extent
        same_padding = (dilation * (kernel_size - 1)) // 2
        act_name = "snake" if use_snake else "elu"

        # NOTE(review): both activations are sized with out_channels, including the
        # one applied before the first conv; this only matters if in_channels differs.
        def make_activation():
            return get_activation(act_name, antialias=antialias_activation, channels=out_channels)

        self.layers = nn.Sequential(
            make_activation(),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=same_padding,
            ),
            make_activation(),
            WNConv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=1),
        )

    def forward(self, x):
        """Return ``layers(x) + x``."""
        return self.layers(x) + x
| |
|
class EncoderBlock(nn.Module):
    """Three residual units (dilations 1, 3, 9), an activation, then a strided conv.

    The residual units keep ``in_channels``; the final ``WNConv1d`` uses
    ``kernel_size=2*stride`` and ``stride=stride`` and maps to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        # NOTE(review): antialias_activation is only applied to the activation in
        # front of the strided conv; the residual units are built without it.
        stages = [
            ResidualUnit(in_channels=in_channels, out_channels=in_channels, dilation=rate, use_snake=use_snake)
            for rate in (1, 3, 9)
        ]
        stages.append(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels)
        )
        stages.append(
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )
        )

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the block's sequential stack."""
        return self.layers(x)
| |
|
class DecoderBlock(nn.Module):
    """Activation, upsampling layer, then three residual units (dilations 1, 3, 9).

    Upsampling is either a ``WNConvTranspose1d`` (default) or nearest-neighbour
    ``nn.Upsample`` followed by a bias-free ``WNConv1d`` with ``padding='same'``.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        # The upsampler is constructed before the other layers, as in the
        # original construction order.
        if not use_nearest_upsample:
            upsample_layer = WNConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )
        else:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=2 * stride,
                    stride=1,
                    bias=False,
                    padding='same',
                ),
            )

        stages = [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            upsample_layer,
        ]
        # NOTE(review): the residual units are built without antialias_activation.
        for rate in (1, 3, 9):
            stages.append(
                ResidualUnit(in_channels=out_channels, out_channels=out_channels, dilation=rate, use_snake=use_snake)
            )

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the block's sequential stack."""
        return self.layers(x)
| |
|
class OobleckEncoder(nn.Module):
    """Convolutional encoder: input conv, one ``EncoderBlock`` per stride, output conv.

    Channel widths are ``channels`` times each entry of ``[1] + c_mults``; the
    final ``WNConv1d`` projects to ``latent_dim``.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
        ):
        super().__init__()

        # Builds a new list, so the shared default argument is never mutated
        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        widths = [mult * channels for mult in c_mults]

        layers = [
            WNConv1d(in_channels=in_channels, out_channels=widths[0], kernel_size=7, padding=3)
        ]

        # NOTE(review): antialias_activation is not forwarded to the EncoderBlocks,
        # only to the final activation below.
        for level in range(self.depth - 1):
            layers.append(
                EncoderBlock(
                    in_channels=widths[level],
                    out_channels=widths[level + 1],
                    stride=strides[level],
                    use_snake=use_snake,
                )
            )

        layers.append(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=widths[-1])
        )
        layers.append(
            WNConv1d(in_channels=widths[-1], out_channels=latent_dim, kernel_size=3, padding=1)
        )

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the encoder stack."""
        return self.layers(x)
| |
|
| |
|
class OobleckDecoder(nn.Module):
    """Convolutional decoder: input conv, one ``DecoderBlock`` per stride, output conv.

    Blocks are built from the widest channel count down to ``channels``; the output
    conv is bias-free and is followed by ``nn.Tanh`` when ``final_tanh`` is true.
    """

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        # Builds a new list, so the shared default argument is never mutated
        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        widths = [mult * channels for mult in c_mults]

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=widths[-1], kernel_size=7, padding=3),
        ]

        # Walk the levels from deepest (widest) to shallowest
        for level in reversed(range(1, self.depth)):
            layers.append(
                DecoderBlock(
                    in_channels=widths[level],
                    out_channels=widths[level - 1],
                    stride=strides[level - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample,
                )
            )

        layers.extend([
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=widths[0]),
            WNConv1d(in_channels=widths[0], out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            nn.Tanh() if final_tanh else nn.Identity(),
        ])

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the decoder stack."""
        return self.layers(x)
| |
|
| |
|
class DACEncoderWrapper(nn.Module):
    """Wrap the ``dac`` package's ``Encoder`` with an optional 1x1 output projection.

    ``latent_dim`` is removed from ``kwargs`` before they are forwarded to the DAC
    encoder; ``d_model`` and ``strides`` must be present in ``kwargs``.
    """

    def __init__(self, in_channels=1, **kwargs):
        super().__init__()

        # Imported lazily so `dac` is only required when this wrapper is used
        from dac.model.dac import Encoder as DACEncoder

        latent_dim = kwargs.pop("latent_dim", None)

        base_width = kwargs["d_model"]
        num_stages = len(kwargs["strides"])
        self.encoder = DACEncoder(d_latent=base_width * (2 ** num_stages), **kwargs)
        self.latent_dim = latent_dim

        if latent_dim is None:
            self.proj_out = nn.Identity()
        else:
            self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1)

        # Replace the first encoder layer when the input is not single-channel
        if in_channels != 1:
            self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)

    def forward(self, x):
        """Encode ``x`` and apply the output projection."""
        return self.proj_out(self.encoder(x))
| |
|
class DACDecoderWrapper(nn.Module):
    """Wrap the ``dac`` package's ``Decoder``.

    ``latent_dim`` is passed as the decoder's ``input_channel`` and
    ``out_channels`` as its ``d_out``; remaining ``kwargs`` are forwarded as-is.
    """

    def __init__(self, latent_dim, out_channels=1, **kwargs):
        super().__init__()

        # Imported lazily so `dac` is only required when this wrapper is used
        from dac.model.dac import Decoder as DACDecoder

        self.decoder = DACDecoder(**kwargs, input_channel=latent_dim, d_out=out_channels)

        self.latent_dim = latent_dim

    def forward(self, x):
        """Decode ``x`` with the wrapped DAC decoder."""
        decoded = self.decoder(x)
        return decoded
| |
|
class AudioAutoencoder(nn.Module):
    """Audio autoencoder built from an encoder, a decoder and optional extra stages.

    Encoding runs ``pretransform.encode`` (optional) -> ``encoder`` ->
    ``bottleneck.encode`` (optional). Decoding runs ``bottleneck.decode``
    (optional) -> ``decoder`` -> ``pretransform.decode`` (optional) ->
    ``tanh`` soft clip (optional).

    Args:
        encoder: module mapping audio to latents; may be ``None``, in which case
            ``encode`` passes the (pretransformed) audio straight through.
        decoder: module mapping latents back to audio.
        latent_dim: number of latent channels (used to size chunked encode output).
        downsampling_ratio: audio samples per latent frame.
        sample_rate: sample rate the model operates at.
        io_channels: default for both ``in_channels`` and ``out_channels``.
        bottleneck: optional bottleneck exposing ``encode``/``decode``/``is_discrete``.
        pretransform: optional pretransform exposing ``encode``/``decode``/``enable_grad``.
        in_channels: overrides ``io_channels`` for the input side when given.
        out_channels: overrides ``io_channels`` for the output side when given.
        soft_clip: apply ``torch.tanh`` to the decoded output.
    """

    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: "Bottleneck | None" = None,
        pretransform: "Pretransform | None" = None,
        in_channels = None,
        out_channels = None,
        soft_clip = False
    ):
        super().__init__()

        self.downsampling_ratio = downsampling_ratio
        self.sample_rate = sample_rate

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels if in_channels is None else in_channels
        self.out_channels = io_channels if out_channels is None else out_channels

        # Inputs are padded to a multiple of this length before encoding
        self.min_length = self.downsampling_ratio

        # Sub-module registration order is kept: bottleneck, encoder, decoder, pretransform
        self.bottleneck = bottleneck
        self.encoder = encoder
        self.decoder = decoder
        self.pretransform = pretransform

        self.soft_clip = soft_clip

        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete

    @staticmethod
    def _map_batch(fn, batch, iterate_batch):
        """Apply ``fn`` to the whole batch, or item by item (batch size 1) and re-concatenate."""
        if not iterate_batch:
            return fn(batch)
        return torch.cat([fn(batch[i:i + 1]) for i in range(batch.shape[0])], dim=0)

    def _apply_pretransform(self, fn, batch, iterate_batch):
        """Run a pretransform method, under ``torch.no_grad()`` unless ``enable_grad`` is set."""
        if self.pretransform.enable_grad:
            # Grad mode is left as the caller set it (not force-enabled)
            return self._map_batch(fn, batch, iterate_batch)
        with torch.no_grad():
            return self._map_batch(fn, batch, iterate_batch)

    @staticmethod
    def _split_chunks(x, chunk_size, hop_size):
        """Cut ``x`` along dim 2 into overlapping windows and stack them on a new dim 0.

        A final window aligned to the end of ``x`` is appended when the regular
        windows do not end exactly at the last position. This also covers the
        case where ``x`` is shorter than ``chunk_size`` (a single, shorter window).
        """
        total_size = x.shape[2]
        starts = list(range(0, total_size - chunk_size + 1, hop_size))
        chunks = [x[:, :, start:start + chunk_size] for start in starts]
        if not starts or starts[-1] + chunk_size != total_size:
            chunks.append(x[:, :, -chunk_size:])
        return torch.stack(chunks)

    @staticmethod
    def _stitch_chunks(chunks, fn, out_channels, out_size, chunk_size, hop_size, overlap, scale):
        """Process each chunk with ``fn`` and write the results into one output tensor.

        ``chunk_size``, ``hop_size`` and ``overlap`` are in latent frames; ``scale``
        converts latent frames to output positions (1 when the output is latents,
        the downsampling ratio when the output is audio). Half of the overlap is
        trimmed from every interior chunk edge; the last chunk is aligned to the
        end of the output.
        """
        num_chunks = chunks.shape[0]
        batch_size = chunks.shape[1]
        y_final = torch.zeros((batch_size, out_channels, out_size)).to(chunks.device)
        trim = (overlap // 2) * scale

        for i in range(num_chunks):
            y_chunk = fn(chunks[i])
            is_first = i == 0
            is_last = i == num_chunks - 1

            if is_last:
                # The final chunk may overlap more, so anchor it to the end
                t_end = out_size
                t_start = t_end - y_chunk.shape[2]
            else:
                t_start = i * hop_size * scale
                t_end = t_start + chunk_size * scale

            chunk_start = 0
            chunk_end = y_chunk.shape[2]
            if not is_first:
                # Drop the leading half-overlap on every chunk but the first
                t_start += trim
                chunk_start += trim
            if not is_last:
                # Drop the trailing half-overlap on every chunk but the last
                t_end -= trim
                chunk_end -= trim

            y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]

        return y_final

    def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
        """Encode audio to latents.

        Args:
            audio: input batch.
            return_info: also return the info dict collected from the bottleneck.
            skip_pretransform: bypass ``pretransform.encode`` even if a pretransform is set.
            iterate_batch: run the pretransform and encoder one batch item at a time.
            **kwargs: forwarded to ``bottleneck.encode``.

        Returns:
            ``latents`` or ``(latents, info)`` when ``return_info`` is true.
        """
        info = {}

        if self.pretransform is not None and not skip_pretransform:
            audio = self._apply_pretransform(self.pretransform.encode, audio, iterate_batch)

        if self.encoder is not None:
            latents = self._map_batch(self.encoder, audio, iterate_batch)
        else:
            latents = audio

        if self.bottleneck is not None:
            # The bottleneck always sees the full batch, even with iterate_batch
            latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
            info.update(bottleneck_info)

        if return_info:
            return latents, info

        return latents

    def decode(self, latents, iterate_batch=False, **kwargs):
        """Decode latents to audio.

        Args:
            latents: latent batch.
            iterate_batch: run every stage one batch item at a time.
            **kwargs: forwarded to the decoder in the non-iterating path only.
                NOTE(review): they are not forwarded when ``iterate_batch`` is true;
                confirm whether that is intended.
        """
        if self.bottleneck is not None:
            # The result must feed the decoder in both modes; the per-item path
            # previously stored it in a variable that was then overwritten.
            latents = self._map_batch(self.bottleneck.decode, latents, iterate_batch)

        if iterate_batch:
            decoded = self._map_batch(self.decoder, latents, True)
        else:
            decoded = self.decoder(latents, **kwargs)

        if self.pretransform is not None:
            decoded = self._apply_pretransform(self.pretransform.decode, decoded, iterate_batch)

        if self.soft_clip:
            decoded = torch.tanh(decoded)

        return decoded

    def decode_tokens(self, tokens, **kwargs):
        '''
        Decode discrete tokens to audio
        Only works with discrete autoencoders
        '''

        assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"

        latents = self.bottleneck.decode_tokens(tokens, **kwargs)

        return self.decode(latents, **kwargs)

    def preprocess_audio_for_encoder(self, audio, in_sr):
        '''
        Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
        If the model is mono, stereo audio will be converted to mono.
        Audio will be silence-padded to be a multiple of the model's downsampling ratio.
        Audio will be resampled to the model's sample rate.
        The output will have batch size 1 and be shape (1 x Channels x Length)
        '''
        return self.preprocess_audio_list_for_encoder([audio], [in_sr])

    def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
        '''
        Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
        The audio in that list can be of different lengths and channels.
        in_sr_list can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
        All audio will be resampled to the model's sample rate.
        Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
        If the model is mono, all audio will be converted to mono.
        The output will be a tensor of shape (Batch x Channels x Length)
        '''
        batch_size = len(audio_list)
        if isinstance(in_sr_list, int):
            in_sr_list = [in_sr_list] * batch_size
        assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"

        new_audio = []
        max_length = 0
        for audio, in_sr in zip(audio_list, in_sr_list):
            if len(audio.shape) == 3 and audio.shape[0] == 1:
                # Tolerate a leading batch dimension of size 1
                audio = audio.squeeze(0)
            elif len(audio.shape) == 1:
                # Treat a 1-D tensor as a single channel
                audio = audio.unsqueeze(0)
            assert len(audio.shape) == 2, "Audio should be shape (Channels x Length) with no batch dimension"

            if in_sr != self.sample_rate:
                resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
                audio = resample_tf(audio)
            new_audio.append(audio)
            max_length = max(max_length, audio.shape[-1])

        # Round the longest length up to the next multiple of min_length
        padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length

        # Everything is already at the model sample rate here, so equal in/target
        # rates are passed explicitly (presumably prepare_audio then only
        # pads/crops and fixes the channel count - it is defined elsewhere).
        new_audio = [
            prepare_audio(
                audio,
                in_sr=self.sample_rate,
                target_sr=self.sample_rate,
                target_length=padded_audio_length,
                target_channels=self.in_channels,
                device=audio.device,
            ).squeeze(0)
            for audio in new_audio
        ]

        return torch.stack(new_audio)

    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
        '''
        Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
        If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
        Overlap and chunk_size params are both measured in number of latents (not audio samples)
        and therefore you likely could use the same values with decode_audio.
        An overlap of zero will cause discontinuity artefacts. Overlap should be >= receptive field size.
        Every autoencoder will have a different receptive field size, and thus ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        Audio shorter than one chunk is encoded as a single chunk.
        Smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
        For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
        Extra keyword arguments are only forwarded to encode() when chunked is False.
        Raises ValueError if chunked is True and overlap is not smaller than chunk_size.
        '''
        if not chunked:
            return self.encode(audio, **kwargs)

        hop_size = chunk_size - overlap
        if hop_size <= 0:
            raise ValueError("overlap must be smaller than chunk_size")

        samples_per_latent = self.downsampling_ratio
        # Windows are cut in audio samples; stitching happens in latent frames
        chunks = self._split_chunks(audio, chunk_size * samples_per_latent, hop_size * samples_per_latent)
        return self._stitch_chunks(
            chunks,
            self.encode,
            out_channels=self.latent_dim,
            out_size=audio.shape[2] // samples_per_latent,
            chunk_size=chunk_size,
            hop_size=hop_size,
            overlap=overlap,
            scale=1,
        )

    def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
        '''
        Decode latents to audio.
        If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
        An overlap of zero will cause discontinuity artefacts. Overlap should be >= receptive field size.
        Every autoencoder will have a different receptive field size, and thus ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        Latents shorter than one chunk are decoded as a single chunk.
        Smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
        For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
        Extra keyword arguments are only forwarded to decode() when chunked is False.
        Raises ValueError if chunked is True and overlap is not smaller than chunk_size.
        '''
        if not chunked:
            return self.decode(latents, **kwargs)

        hop_size = chunk_size - overlap
        if hop_size <= 0:
            raise ValueError("overlap must be smaller than chunk_size")

        samples_per_latent = self.downsampling_ratio
        chunks = self._split_chunks(latents, chunk_size, hop_size)
        return self._stitch_chunks(
            chunks,
            self.decode,
            out_channels=self.out_channels,
            out_size=latents.shape[2] * samples_per_latent,
            chunk_size=chunk_size,
            hop_size=hop_size,
            overlap=overlap,
            scale=samples_per_latent,
        )
| |
|
| | |
| | |
| |
|
def create_encoder_from_config(encoder_config: Dict[str, Any]):
    """Instantiate an encoder from its config dict.

    Args:
        encoder_config: dict with a ``"type"`` key (``"oobleck"``, ``"seanet"``,
            ``"dac"`` or ``"local_attn"``), a ``"config"`` dict of constructor
            keyword arguments, and an optional ``"requires_grad"`` flag
            (default ``True``).

    Returns:
        The constructed encoder; all of its parameters are frozen when
        ``requires_grad`` is false.

    Raises:
        AssertionError: if ``"type"`` is missing.
        ValueError: if ``"type"`` is not a known encoder type.
    """
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"

    if encoder_type == "oobleck":
        encoder = OobleckEncoder(
            **encoder_config["config"]
        )

    elif encoder_type == "seanet":
        from encodec.modules import SEANetEncoder

        # Work on a shallow copy: reversing "ratios" in place on the caller's dict
        # would flip them back on a second build from the same config.
        seanet_encoder_config = dict(encoder_config["config"])

        # SEANet is given the ratios in the opposite order to how they are configured
        seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
        encoder = SEANetEncoder(
            **seanet_encoder_config
        )
    elif encoder_type == "dac":
        dac_config = encoder_config["config"]

        encoder = DACEncoderWrapper(**dac_config)
    elif encoder_type == "local_attn":
        from .local_attention import TransformerEncoder1D

        local_attn_config = encoder_config["config"]

        encoder = TransformerEncoder1D(
            **local_attn_config
        )
    else:
        raise ValueError(f"Unknown encoder type {encoder_type}")

    requires_grad = encoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in encoder.parameters():
            param.requires_grad = False

    return encoder
| |
|
def create_decoder_from_config(decoder_config: Dict[str, Any]):
    """Instantiate a decoder from its config dict.

    ``decoder_config["type"]`` selects the implementation (``"oobleck"``,
    ``"seanet"``, ``"dac"`` or ``"local_attn"``) and ``decoder_config["config"]``
    supplies its constructor keyword arguments. When ``"requires_grad"`` is
    present and false, every parameter of the decoder is frozen.

    Raises:
        AssertionError: if ``"type"`` is missing.
        ValueError: if ``"type"`` is not a known decoder type.
    """
    decoder_type = decoder_config.get("type", None)
    assert decoder_type is not None, "Decoder type must be specified"

    if decoder_type == "oobleck":
        decoder = OobleckDecoder(**decoder_config["config"])
    elif decoder_type == "seanet":
        from encodec.modules import SEANetDecoder

        decoder = SEANetDecoder(**decoder_config["config"])
    elif decoder_type == "dac":
        decoder = DACDecoderWrapper(**decoder_config["config"])
    elif decoder_type == "local_attn":
        from .local_attention import TransformerDecoder1D

        decoder = TransformerDecoder1D(**decoder_config["config"])
    else:
        raise ValueError(f"Unknown decoder type {decoder_type}")

    trainable = decoder_config.get("requires_grad", True)
    if not trainable:
        for weight in decoder.parameters():
            weight.requires_grad = False

    return decoder
| |
|
def create_autoencoder_from_config(config: Dict[str, Any]):
    """Build an ``AudioAutoencoder`` from a full model config.

    Reads ``config["model"]`` for the encoder/decoder configs, ``latent_dim``,
    ``downsampling_ratio``, ``io_channels`` and the optional ``bottleneck``,
    ``pretransform``, ``in_channels`` and ``out_channels`` entries;
    ``sample_rate`` is read from the top level of ``config``. ``soft_clip`` is
    taken from the decoder config (default ``False``).

    Raises:
        AssertionError: if a required value is missing. The encoder and decoder
            are constructed before these checks run.
    """
    ae_config = config["model"]

    encoder = create_encoder_from_config(ae_config["encoder"])
    decoder = create_decoder_from_config(ae_config["decoder"])

    def required(source, key):
        # Fetch a mandatory value, failing with the standard message when absent
        value = source.get(key, None)
        assert value is not None, f"{key} must be specified in model config"
        return value

    bottleneck = ae_config.get("bottleneck", None)

    latent_dim = required(ae_config, "latent_dim")
    downsampling_ratio = required(ae_config, "downsampling_ratio")
    io_channels = required(ae_config, "io_channels")
    sample_rate = required(config, "sample_rate")

    in_channels = ae_config.get("in_channels", None)
    out_channels = ae_config.get("out_channels", None)

    # The pretransform is built before the bottleneck
    pretransform = ae_config.get("pretransform", None)
    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = create_bottleneck_from_config(bottleneck)

    soft_clip = ae_config["decoder"].get("soft_clip", False)

    return AudioAutoencoder(
        encoder,
        decoder,
        io_channels=io_channels,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        sample_rate=sample_rate,
        bottleneck=bottleneck,
        pretransform=pretransform,
        in_channels=in_channels,
        out_channels=out_channels,
        soft_clip=soft_clip,
    )