| from pathlib import Path |
| from typing import Callable |
| from typing import Dict |
| from typing import List |
| from typing import Union |
|
|
| import numpy as np |
| from torch.utils.data import SequentialSampler |
| from torch.utils.data.distributed import DistributedSampler |
|
|
| from ..core import AudioSignal |
| from ..core import util |
|
|
|
|
class AudioLoader:
    """Loads audio endlessly from a list of audio sources
    containing paths to audio files. Audio sources can be
    folders full of audio files (which are found via file
    extension) or by providing a CSV file which contains paths
    to audio files.

    Parameters
    ----------
    sources : List[str], optional
        Sources containing folders, or CSVs with
        paths to audio files, by default None
    weights : List[float], optional
        Weights to sample audio files from each source, by default None
    transform : Callable, optional
        Transform to instantiate alongside audio sample,
        by default None
    relative_path : str, optional
        Path audio should be loaded relative to, by default ""
    ext : List[str]
        List of extensions to find audio within each source by. Can
        also be a file name (e.g. "vocals.wav"). by default
        ``['.wav', '.flac', '.mp3', '.mp4']``.
    shuffle: bool
        Whether to shuffle the files within the dataloader. Defaults to True.
    shuffle_state: int
        State to use to seed the shuffle of the files.
    """

    def __init__(
        self,
        sources: List[str] = None,
        weights: List[float] = None,
        transform: Callable = None,
        relative_path: str = "",
        ext: List[str] = util.AUDIO_EXTENSIONS,
        shuffle: bool = True,
        shuffle_state: int = 0,
    ):
        # One list of {"path": ...} info dicts per source.
        self.audio_lists = util.read_sources(
            sources, relative_path=relative_path, ext=ext
        )

        # Flat (source_idx, item_idx) index over every file, used for
        # without-replacement sampling via ``global_idx`` in ``__call__``.
        self.audio_indices = [
            (src_idx, item_idx)
            for src_idx, src in enumerate(self.audio_lists)
            for item_idx in range(len(src))
        ]
        if shuffle:
            state = util.random_state(shuffle_state)
            state.shuffle(self.audio_indices)

        self.sources = sources
        self.weights = weights
        self.transform = transform

    def __call__(
        self,
        state,
        sample_rate: int,
        duration: float,
        loudness_cutoff: float = -40,
        num_channels: int = 1,
        offset: float = None,
        source_idx: int = None,
        item_idx: int = None,
        global_idx: int = None,
    ):
        """Load one excerpt of audio.

        The file is chosen in one of three ways: explicitly via
        ``source_idx``/``item_idx``, deterministically via ``global_idx``
        (without replacement), or randomly via ``state`` weighted by
        ``self.weights``. If ``offset`` is None, a salient excerpt above
        ``loudness_cutoff`` is drawn; otherwise audio is read at the fixed
        offset. Returns a dict with keys ``signal``, ``source_idx``,
        ``item_idx``, ``source``, ``path``, and (if a transform is set)
        ``transform_args``.
        """
        if source_idx is not None and item_idx is not None:
            try:
                audio_info = self.audio_lists[source_idx][item_idx]
            # Was a bare ``except:``; an out-of-range lookup (e.g. an
            # unaligned multitrack item) falls back to silence below.
            except IndexError:
                audio_info = {"path": "none"}
        elif global_idx is not None:
            # Deterministic without-replacement choice: walk the shuffled
            # flat index, wrapping around when the dataset is exhausted.
            source_idx, item_idx = self.audio_indices[
                global_idx % len(self.audio_indices)
            ]
            audio_info = self.audio_lists[source_idx][item_idx]
        else:
            audio_info, source_idx, item_idx = util.choose_from_list_of_lists(
                state, self.audio_lists, p=self.weights
            )

        path = audio_info["path"]
        # Default to silence; used when path is the "none" placeholder.
        signal = AudioSignal.zeros(duration, sample_rate, num_channels)

        if path != "none":
            if offset is None:
                signal = AudioSignal.salient_excerpt(
                    path,
                    duration=duration,
                    state=state,
                    loudness_cutoff=loudness_cutoff,
                )
            else:
                signal = AudioSignal(
                    path,
                    offset=offset,
                    duration=duration,
                )

        if num_channels == 1:
            signal = signal.to_mono()
        signal = signal.resample(sample_rate)

        # Short files are padded out so every excerpt has equal length.
        if signal.duration < duration:
            signal = signal.zero_pad_to(int(duration * sample_rate))

        # Carry source metadata (e.g. extra CSV columns) on the signal.
        for k, v in audio_info.items():
            signal.metadata[k] = v

        item = {
            "signal": signal,
            "source_idx": source_idx,
            "item_idx": item_idx,
            "source": str(self.sources[source_idx]),
            "path": str(path),
        }
        if self.transform is not None:
            item["transform_args"] = self.transform.instantiate(state, signal=signal)
        return item
|
|
|
|
def default_matcher(x, y):
    """Return True when the two paths share the same parent directory."""
    x_dir = Path(x).parent
    y_dir = Path(y).parent
    return x_dir == y_dir
|
|
|
|
def align_lists(lists, matcher: Callable = default_matcher):
    """Align several audio lists so the same index refers to matching items.

    Walks the longest list as a reference; wherever another list's entry at
    the same index does not match (per ``matcher``), a ``{"path": "none"}``
    placeholder is inserted, and lists that run short are padded at the end.
    Mutates ``lists`` in place and returns it.
    """
    reference = max(lists, key=len)
    for pos, ref_item in enumerate(reference):
        for current in lists:
            if pos >= len(current):
                current.append({"path": "none"})
                continue
            if not matcher(current[pos]["path"], ref_item["path"]):
                current.insert(pos, {"path": "none"})
    return lists
|
|
|
|
class AudioDataset:
    """Loads audio from multiple loaders (with associated transforms)
    for a specified number of samples. Excerpts are drawn randomly
    of the specified duration, above a specified loudness threshold
    and are resampled on the fly to the desired sample rate
    (if it is different from the audio source sample rate).

    This takes either a single AudioLoader object,
    a list of AudioLoader objects, or a dictionary of AudioLoader
    objects. Each AudioLoader is called by the dataset, and the
    result is placed in the output dictionary. A transform can also be
    specified for the entire dataset, rather than for each specific
    loader. This transform can be applied to the output of all the
    loaders if desired.

    AudioLoader objects can be specified as aligned, which means the
    loaders correspond to multitrack audio (e.g. a vocals, bass,
    drums, and other loader for multitrack music mixtures).


    Parameters
    ----------
    loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]]
        AudioLoaders to sample audio from.
    sample_rate : int
        Desired sample rate.
    n_examples : int, optional
        Number of examples (length of dataset), by default 1000
    duration : float, optional
        Duration of audio samples, by default 0.5
    offset : float, optional
        If given, audio is always loaded at this fixed offset (in seconds)
        instead of drawing a random salient excerpt, by default None
    loudness_cutoff : float, optional
        Loudness cutoff threshold for audio samples, by default -40
    num_channels : int, optional
        Number of channels in output audio, by default 1
    transform : Callable, optional
        Transform to instantiate alongside each dataset item, by default None
    aligned : bool, optional
        Whether the loaders should be sampled in an aligned manner (e.g. same
        offset, duration, and matched file name), by default False
    shuffle_loaders : bool, optional
        Whether to shuffle the loaders before sampling from them, by default False
    matcher : Callable
        How to match files from adjacent audio lists (e.g. for a multitrack audio loader),
        by default uses the parent directory of each file.
    without_replacement : bool
        Whether to choose files with or without replacement, by default True.


    Examples
    --------
    >>> from audiotools.data.datasets import AudioLoader
    >>> from audiotools.data.datasets import AudioDataset
    >>> from audiotools import transforms as tfm
    >>> import numpy as np
    >>>
    >>> loaders = [
    >>>     AudioLoader(
    >>>         sources=[f"tests/audio/spk"],
    >>>         transform=tfm.Equalizer(),
    >>>         ext=["wav"],
    >>>     )
    >>>     for i in range(5)
    >>> ]
    >>>
    >>> dataset = AudioDataset(
    >>>     loaders = loaders,
    >>>     sample_rate = 44100,
    >>>     duration = 1.0,
    >>>     transform = tfm.RescaleAudio(),
    >>> )
    >>>
    >>> item = dataset[np.random.randint(len(dataset))]
    >>>
    >>> for i in range(len(loaders)):
    >>>     item[i]["signal"] = loaders[i].transform(
    >>>         item[i]["signal"], **item[i]["transform_args"]
    >>>     )
    >>>     item[i]["signal"].widget(i)
    >>>
    >>> mix = sum([item[i]["signal"] for i in range(len(loaders))])
    >>> mix = dataset.transform(mix, **item["transform_args"])
    >>> mix.widget("mix")

    Below is an example of how one could load MUSDB multitrack data:

    >>> import audiotools as at
    >>> from pathlib import Path
    >>> from audiotools import transforms as tfm
    >>> import numpy as np
    >>> import torch
    >>>
    >>> def build_dataset(
    >>>     sample_rate: int = 44100,
    >>>     duration: float = 5.0,
    >>>     musdb_path: str = "~/.data/musdb/",
    >>> ):
    >>>     musdb_path = Path(musdb_path).expanduser()
    >>>     loaders = {
    >>>         src: at.datasets.AudioLoader(
    >>>             sources=[musdb_path],
    >>>             transform=tfm.Compose(
    >>>                 tfm.VolumeNorm(("uniform", -20, -10)),
    >>>                 tfm.Silence(prob=0.1),
    >>>             ),
    >>>             ext=[f"{src}.wav"],
    >>>         )
    >>>         for src in ["vocals", "bass", "drums", "other"]
    >>>     }
    >>>
    >>>     dataset = at.datasets.AudioDataset(
    >>>         loaders=loaders,
    >>>         sample_rate=sample_rate,
    >>>         duration=duration,
    >>>         num_channels=1,
    >>>         aligned=True,
    >>>         transform=tfm.RescaleAudio(),
    >>>         shuffle_loaders=True,
    >>>     )
    >>>     return dataset, list(loaders.keys())
    >>>
    >>> train_data, sources = build_dataset()
    >>> dataloader = torch.utils.data.DataLoader(
    >>>     train_data,
    >>>     batch_size=16,
    >>>     num_workers=0,
    >>>     collate_fn=train_data.collate,
    >>> )
    >>> batch = next(iter(dataloader))
    >>>
    >>> for k in sources:
    >>>     src = batch[k]
    >>>     src["transformed"] = train_data.loaders[k].transform(
    >>>         src["signal"].clone(), **src["transform_args"]
    >>>     )
    >>>
    >>> mixture = sum(batch[k]["transformed"] for k in sources)
    >>> mixture = train_data.transform(mixture, **batch["transform_args"])
    >>>
    >>> # Say a model takes the mix and gives back (n_batch, n_src, n_time).
    >>> # Construct the targets:
    >>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1)

    Similarly, here's example code for loading Slakh data:

    >>> import audiotools as at
    >>> from pathlib import Path
    >>> from audiotools import transforms as tfm
    >>> import numpy as np
    >>> import torch
    >>> import glob
    >>>
    >>> def build_dataset(
    >>>     sample_rate: int = 16000,
    >>>     duration: float = 10.0,
    >>>     slakh_path: str = "~/.data/slakh/",
    >>> ):
    >>>     slakh_path = Path(slakh_path).expanduser()
    >>>
    >>>     # Find the max number of sources in Slakh
    >>>     src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)]
    >>>     n_sources = len(list(set(src_names)))
    >>>
    >>>     loaders = {
    >>>         f"S{i:02d}": at.datasets.AudioLoader(
    >>>             sources=[slakh_path],
    >>>             transform=tfm.Compose(
    >>>                 tfm.VolumeNorm(("uniform", -20, -10)),
    >>>                 tfm.Silence(prob=0.1),
    >>>             ),
    >>>             ext=[f"S{i:02d}.wav"],
    >>>         )
    >>>         for i in range(n_sources)
    >>>     }
    >>>     dataset = at.datasets.AudioDataset(
    >>>         loaders=loaders,
    >>>         sample_rate=sample_rate,
    >>>         duration=duration,
    >>>         num_channels=1,
    >>>         aligned=True,
    >>>         transform=tfm.RescaleAudio(),
    >>>         shuffle_loaders=False,
    >>>     )
    >>>
    >>>     return dataset, list(loaders.keys())
    >>>
    >>> train_data, sources = build_dataset()
    >>> dataloader = torch.utils.data.DataLoader(
    >>>     train_data,
    >>>     batch_size=16,
    >>>     num_workers=0,
    >>>     collate_fn=train_data.collate,
    >>> )
    >>> batch = next(iter(dataloader))
    >>>
    >>> for k in sources:
    >>>     src = batch[k]
    >>>     src["transformed"] = train_data.loaders[k].transform(
    >>>         src["signal"].clone(), **src["transform_args"]
    >>>     )
    >>>
    >>> mixture = sum(batch[k]["transformed"] for k in sources)
    >>> mixture = train_data.transform(mixture, **batch["transform_args"])

    """

    def __init__(
        self,
        loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]],
        sample_rate: int,
        n_examples: int = 1000,
        duration: float = 0.5,
        offset: float = None,
        loudness_cutoff: float = -40,
        num_channels: int = 1,
        transform: Callable = None,
        aligned: bool = False,
        shuffle_loaders: bool = False,
        matcher: Callable = default_matcher,
        without_replacement: bool = True,
    ):
        # Normalize loaders to a dict so items can always be keyed.
        if isinstance(loaders, list):
            loaders = {i: l for i, l in enumerate(loaders)}
        elif isinstance(loaders, AudioLoader):
            loaders = {0: loaders}

        self.loaders = loaders
        self.loudness_cutoff = loudness_cutoff
        self.num_channels = num_channels

        self.length = n_examples
        self.transform = transform
        self.sample_rate = sample_rate
        self.duration = duration
        self.offset = offset
        self.aligned = aligned
        self.shuffle_loaders = shuffle_loaders
        self.without_replacement = without_replacement

        if aligned:
            # Pad/align the i-th audio list of every loader against one
            # another, so the same (source_idx, item_idx) maps to matching
            # multitrack files across loaders.
            loaders_list = list(loaders.values())
            for i in range(len(loaders_list[0].audio_lists)):
                input_lists = [l.audio_lists[i] for l in loaders_list]
                # Alignment happens in-place
                align_lists(input_lists, matcher)

    def __getitem__(self, idx):
        """Draw one item from every loader, seeded deterministically by ``idx``."""
        state = util.random_state(idx)
        item = {}

        keys = list(self.loaders.keys())
        if self.shuffle_loaders:
            state.shuffle(keys)

        loader_kwargs = {
            "state": state,
            "sample_rate": self.sample_rate,
            "duration": self.duration,
            "loudness_cutoff": self.loudness_cutoff,
            "num_channels": self.num_channels,
            "global_idx": idx if self.without_replacement else None,
            # Fix: ``self.offset`` was previously computed but never
            # forwarded to the loaders, so the ``offset`` constructor
            # argument was silently ignored. ``None`` keeps the old
            # salient-excerpt behavior.
            "offset": self.offset,
        }

        # Draw item from first loader; the other loaders may be slaved to it
        # below when ``aligned`` is True.
        loader = self.loaders[keys[0]]
        item[keys[0]] = loader(**loader_kwargs)

        for key in keys[1:]:
            loader = self.loaders[key]
            if self.aligned:
                # Path mapping gives the same signal position from a
                # different source: reuse the first item's offset and indices.
                # NOTE(review): relies on AudioSignal populating
                # metadata["offset"] for loaded excerpts — confirm for
                # fixed-offset loads as well as salient excerpts.
                offset = item[keys[0]]["signal"].metadata["offset"]
                loader_kwargs.update(
                    {
                        "offset": offset,
                        "source_idx": item[keys[0]]["source_idx"],
                        "item_idx": item[keys[0]]["item_idx"],
                    }
                )
            item[key] = loader(**loader_kwargs)

        # Sort dictionary back into original loader order, regardless of
        # whether the keys were shuffled above.
        keys = list(self.loaders.keys())
        item = {k: item[k] for k in keys}

        item["idx"] = idx
        if self.transform is not None:
            item["transform_args"] = self.transform.instantiate(
                state=state, signal=item[keys[0]]["signal"]
            )

        # If there's only one loader, pop it up to the main dictionary so
        # ``item["signal"]`` etc. work directly, instead of requiring
        # ``item[0]["signal"]``.
        if len(keys) == 1:
            item.update(item.pop(keys[0]))

        return item

    def __len__(self):
        return self.length

    @staticmethod
    def collate(list_of_dicts: Union[list, dict], n_splits: int = None):
        """Collates items drawn from this dataset. Uses
        :py:func:`audiotools.core.util.collate`.

        Parameters
        ----------
        list_of_dicts : typing.Union[list, dict]
            Data drawn from each item.
        n_splits : int
            Number of splits to make when creating the batches (split into
            sub-batches). Useful for things like gradient accumulation.

        Returns
        -------
        dict
            Dictionary of batched data.
        """
        return util.collate(list_of_dicts, n_splits=n_splits)
|
|
|
|
class ConcatDataset(AudioDataset):
    """Round-robin concatenation of datasets: item ``idx`` is drawn from
    dataset ``idx % n_datasets`` at position ``idx // n_datasets``."""

    def __init__(self, datasets: list):
        self.datasets = datasets

    def __len__(self):
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        inner_idx, dataset_idx = divmod(idx, len(self.datasets))
        return self.datasets[dataset_idx][inner_idx]
|
|
|
|
class ResumableDistributedSampler(DistributedSampler):
    """Distributed sampler that can be resumed from a given start index."""

    def __init__(self, dataset, start_idx: int = None, **kwargs):
        super().__init__(dataset, **kwargs)
        # Convert the global start index into a per-replica skip count, since
        # each replica only sees every ``num_replicas``-th sample.
        self.start_idx = 0 if start_idx is None else start_idx // self.num_replicas

    def __iter__(self):
        to_skip = self.start_idx
        for sample_idx in super().__iter__():
            if to_skip > 0:
                to_skip -= 1
            else:
                yield sample_idx
        # Resumption only applies to the first full pass; later epochs
        # start from the beginning.
        self.start_idx = 0
|
|
|
|
class ResumableSequentialSampler(SequentialSampler):
    """Sequential sampler that can be resumed from a given start index."""

    def __init__(self, dataset, start_idx: int = None, **kwargs):
        super().__init__(dataset, **kwargs)
        # Number of leading samples to skip on the next (resumed) pass.
        self.start_idx = 0 if start_idx is None else start_idx

    def __iter__(self):
        to_skip = self.start_idx
        for sample_idx in super().__iter__():
            if to_skip > 0:
                to_skip -= 1
            else:
                yield sample_idx
        # Resumption only applies to the first full pass; later epochs
        # start from the beginning.
        self.start_idx = 0
|
|