| | import math
|
| | from dataclasses import dataclass
|
| | from pathlib import Path
|
| | from typing import Union
|
| |
|
| | import numpy as np
|
| | import torch
|
| | import tqdm
|
| | from audiotools import AudioSignal
|
| | from torch import nn
|
| |
|
# File-format versions this module accepts when loading a ``.dac`` file.
# The last entry is the version stamped into newly written files.
SUPPORTED_VERSIONS = ["1.0.0"]
|
| |
|
| |
|
@dataclass
class DACFile:
    """Compressed audio codes together with the metadata needed to decode them.

    Instances are produced by ``CodecMixin.compress`` and consumed by
    ``CodecMixin.decompress``; ``save``/``load`` round-trip them through a
    ``.dac`` file on disk.
    """

    # Integer code tensor; stored on disk as ``uint16``.
    codes: torch.Tensor

    # Metadata
    # Number of code frames produced per encoded chunk.
    chunk_length: int
    # Length (in samples) of the signal before it was resampled for encoding.
    original_length: int
    # Loudness measured before normalisation. Annotated as ``float`` but in
    # practice also arrives as a torch tensor (from ``compress``) or as a
    # numpy array (from ``load``); ``save`` accepts all three.
    input_db: float
    channels: int
    # Sample rate of the original signal (before resampling to the codec rate).
    sample_rate: int
    # Whether the codec ran with convolution padding enabled when encoding.
    padding: bool
    dac_version: str

    def save(self, path):
        """Write the codes and metadata to ``path`` with a ``.dac`` suffix.

        Parameters
        ----------
        path : str or Path
            Destination; its suffix is replaced by ``.dac``.

        Returns
        -------
        Path
            The path that was actually written.
        """
        # ``input_db`` may be a python float, a numpy array (after ``load``)
        # or a torch tensor (after ``compress``); ``torch.as_tensor`` accepts
        # all of them, so a loaded file can be saved again.
        input_db = torch.as_tensor(self.input_db).detach().cpu().numpy()
        artifacts = {
            # detach()/cpu() make this work for tensors living on an
            # accelerator or attached to an autograd graph.
            "codes": self.codes.detach().cpu().numpy().astype(np.uint16),
            "metadata": {
                "input_db": input_db.astype(np.float32),
                "original_length": self.original_length,
                "sample_rate": self.sample_rate,
                "chunk_length": self.chunk_length,
                "channels": self.channels,
                "padding": self.padding,
                # Files are always stamped with the newest supported version.
                "dac_version": SUPPORTED_VERSIONS[-1],
            },
        }
        path = Path(path).with_suffix(".dac")
        with open(path, "wb") as f:
            np.save(f, artifacts)
        return path

    @classmethod
    def load(cls, path):
        """Read a ``.dac`` file written by ``save``.

        Parameters
        ----------
        path : str or Path
            Location of the ``.dac`` file.

        Returns
        -------
        DACFile
            Object whose ``codes`` is an integer torch tensor and whose
            metadata fields are taken verbatim from the file.

        Raises
        ------
        RuntimeError
            If the file's ``dac_version`` is missing or not listed in
            ``SUPPORTED_VERSIONS``.
        """
        # NOTE(security): ``allow_pickle=True`` unpickles the file contents,
        # which can execute arbitrary code. Only load files from trusted
        # sources.
        artifacts = np.load(path, allow_pickle=True)[()]
        codes = torch.from_numpy(artifacts["codes"].astype(int))
        metadata = artifacts["metadata"]
        if metadata.get("dac_version", None) not in SUPPORTED_VERSIONS:
            raise RuntimeError(
                f"Given file {path} can't be loaded with this version of descript-audio-codec."
            )
        return cls(codes=codes, **metadata)
|
| |
|
| |
|
class CodecMixin:
    """Mixin adding padding control and chunked (de)compression to a codec.

    The host class is expected to provide what the methods below use:
    ``modules()`` and ``eval()`` (as on ``torch.nn.Module``), the attributes
    ``device``, ``sample_rate``, ``hop_length`` and ``delay``, the methods
    ``preprocess``, ``encode`` and ``decode``, and ``quantizer.from_codes``.
    """

    @property
    def padding(self):
        """Whether convolution padding is currently enabled (defaults to True)."""
        if not hasattr(self, "_padding"):
            self._padding = True
        return self._padding

    @padding.setter
    def padding(self, value):
        """Enable or disable padding on every 1-D (transposed) convolution.

        Disabling stores each layer's padding in ``layer.original_padding``
        and replaces it with zeros; enabling restores the stored value.
        """
        assert isinstance(value, bool)

        layers = [
            l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
        ]

        for layer in layers:
            if value:
                if hasattr(layer, "original_padding"):
                    layer.padding = layer.original_padding
            else:
                # Remember the real padding only the first time it is
                # disabled; otherwise a second ``padding = False`` would
                # overwrite it with zeros and it could never be restored.
                if not hasattr(layer, "original_padding"):
                    layer.original_padding = layer.padding
                layer.padding = tuple(0 for _ in range(len(layer.padding)))

        self._padding = value

    def get_delay(self):
        """Return half the difference between input and output length.

        Starts from the output length obtained for a zero-length input and
        walks the convolution layers backwards, inverting each layer's
        length formula, to find the matching input length.
        """
        # Any number works here, delay is invariant to input length
        l_out = self.get_output_length(0)
        L = l_out

        layers = [
            layer
            for layer in self.modules()
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d))
        ]

        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            # Inverse of the forward formulas in ``get_output_length``.
            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L

        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        """Return the length produced by the conv stack for ``input_length``.

        Applies the unpadded length formula of every ``Conv1d`` /
        ``ConvTranspose1d`` in ``self.modules()`` order, flooring after each
        layer. With no such layers the input length is returned unchanged.
        """
        L = input_length
        # Calculate output length
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                d = layer.dilation[0]
                k = layer.kernel_size[0]
                s = layer.stride[0]

                if isinstance(layer, nn.Conv1d):
                    L = ((L - d * (k - 1) - 1) / s) + 1
                elif isinstance(layer, nn.ConvTranspose1d):
                    L = (L - 1) * s + d * (k - 1) + 1

                L = math.floor(L)
        return L

    @torch.no_grad()
    def compress(
        self,
        audio_path_or_signal: Union[str, Path, "AudioSignal"],
        win_duration: float = 1.0,
        verbose: bool = False,
        normalize_db: float = -16,
        n_quantizers: int = None,
    ) -> "DACFile":
        """Processes an audio signal from a file or AudioSignal object into
        discrete codes. This function processes the signal in short windows,
        using constant GPU memory.

        Parameters
        ----------
        audio_path_or_signal : Union[str, Path, AudioSignal]
            audio signal to compress; a path is loaded via ffmpeg
        win_duration : float, optional
            window duration in seconds, by default 1.0. ``None`` means the
            whole signal is encoded as a single window.
        verbose : bool, optional
            show a progress bar over the windows, by default False
        normalize_db : float, optional
            loudness to normalize to before encoding, by default -16.
            ``None`` skips normalization.
        n_quantizers : int, optional
            passed to ``self.encode``; when given, the stored codes are also
            limited to the first ``n_quantizers`` codebooks. By default None.

        Returns
        -------
        DACFile
            Object containing compressed codes and metadata
            required for decompression
        """
        audio_signal = audio_path_or_signal
        if isinstance(audio_signal, (str, Path)):
            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))

        self.eval()
        original_padding = self.padding
        original_device = audio_signal.device

        # Work on a copy so the caller's signal is not modified.
        audio_signal = audio_signal.clone()
        original_sr = audio_signal.sample_rate

        resample_fn = audio_signal.resample
        loudness_fn = audio_signal.loudness

        # If audio is >= 10 hours long, use the ffmpeg versions of these functions
        if audio_signal.signal_duration >= 10 * 60 * 60:
            resample_fn = audio_signal.ffmpeg_resample
            loudness_fn = audio_signal.ffmpeg_loudness

        # Recorded before resampling so decompress can trim to this length.
        original_length = audio_signal.signal_length
        resample_fn(self.sample_rate)
        input_db = loudness_fn()

        if normalize_db is not None:
            audio_signal.normalize(normalize_db)
        audio_signal.ensure_max_of_audio()

        # Fold channels into the batch dimension so each is coded as mono.
        nb, nac, nt = audio_signal.audio_data.shape
        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
        win_duration = (
            audio_signal.signal_duration if win_duration is None else win_duration
        )

        # Everything below may toggle ``self.padding``; restore it even if
        # encoding fails part-way through.
        try:
            if audio_signal.signal_duration <= win_duration:
                # Unchunked compression (used if signal length < win duration)
                self.padding = True
                n_samples = nt
                hop = nt
            else:
                # Chunked inference
                self.padding = False
                # Zero-pad signal on either side by the delay
                audio_signal.zero_pad(self.delay, self.delay)
                n_samples = int(win_duration * self.sample_rate)
                # Round n_samples up to a multiple of the hop length
                n_samples = int(
                    math.ceil(n_samples / self.hop_length) * self.hop_length
                )
                hop = self.get_output_length(n_samples)

            codes = []
            range_fn = tqdm.trange if verbose else range

            for i in range_fn(0, nt, hop):
                x = audio_signal[..., i : i + n_samples]
                # The final window may be short; right-pad it to full size.
                x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

                audio_data = x.audio_data.to(self.device)
                audio_data = self.preprocess(audio_data, self.sample_rate)
                # Only the second element (the codes) of encode's result is used.
                _, c, _, _, _ = self.encode(audio_data, n_quantizers)
                codes.append(c.to(original_device))
                chunk_length = c.shape[-1]

            codes = torch.cat(codes, dim=-1)

            # Must happen before the DACFile is built; slicing afterwards
            # would have no effect on the returned object.
            if n_quantizers is not None:
                codes = codes[:, :n_quantizers, :]

            dac_file = DACFile(
                codes=codes,
                chunk_length=chunk_length,
                original_length=original_length,
                input_db=input_db,
                channels=nac,
                sample_rate=original_sr,
                # Record the padding mode actually used for encoding.
                padding=self.padding,
                dac_version=SUPPORTED_VERSIONS[-1],
            )
        finally:
            self.padding = original_padding

        return dac_file

    @torch.no_grad()
    def decompress(
        self,
        obj: Union[str, Path, "DACFile"],
        verbose: bool = False,
    ) -> "AudioSignal":
        """Reconstruct audio from a given .dac file

        Parameters
        ----------
        obj : Union[str, Path, DACFile]
            .dac file location or corresponding DACFile object.
        verbose : bool, optional
            Prints progress if True, by default False

        Returns
        -------
        AudioSignal
            Object with the reconstructed audio
        """
        self.eval()
        if isinstance(obj, (str, Path)):
            obj = DACFile.load(obj)

        original_padding = self.padding
        # Restore the caller's padding mode even if decoding fails.
        try:
            # Decode with the same padding mode that was used for encoding.
            self.padding = obj.padding

            range_fn = tqdm.trange if verbose else range
            codes = obj.codes
            original_device = codes.device
            chunk_length = obj.chunk_length
            recons = []

            for i in range_fn(0, codes.shape[-1], chunk_length):
                c = codes[..., i : i + chunk_length].to(self.device)
                # Only the first element of from_codes' result is decoded.
                z = self.quantizer.from_codes(c)[0]
                r = self.decode(z)
                recons.append(r.to(original_device))

            recons = torch.cat(recons, dim=-1)
            recons = AudioSignal(recons, self.sample_rate)

            resample_fn = recons.resample
            loudness_fn = recons.loudness

            # If audio is >= 10 hours long, use the ffmpeg versions of these functions
            if recons.signal_duration >= 10 * 60 * 60:
                resample_fn = recons.ffmpeg_resample
                loudness_fn = recons.ffmpeg_loudness

            recons.normalize(obj.input_db)
            resample_fn(obj.sample_rate)
            recons = recons[..., : obj.original_length]
            # NOTE(review): the result is discarded and ``loudness_fn`` is
            # still bound to the signal from before the trim above. Kept
            # as-is; presumably called for a side effect - TODO confirm.
            loudness_fn()
            # Undo the channel-into-batch folding done by ``compress``.
            recons.audio_data = recons.audio_data.reshape(
                -1, obj.channels, obj.original_length
            )
        finally:
            self.padding = original_padding

        return recons
|
| |
|