| | import torch |
| | import librosa |
| | import numpy as np |
| | import random |
| | import os |
| | from torch.utils.data import DataLoader |
| | from modules.audio import mel_spectrogram |
| |
|
| |
|
| | duration_setting = { |
| | "min": 1.0, |
| | "max": 30.0, |
| | } |
| | |
| | def to_mel_fn(wave, mel_fn_args): |
| | return mel_spectrogram(wave, **mel_fn_args) |
| |
|
| | class FT_Dataset(torch.utils.data.Dataset): |
| | def __init__( |
| | self, |
| | data_path, |
| | spect_params, |
| | sr=22050, |
| | batch_size=1, |
| | ): |
| | self.data_path = data_path |
| | self.data = [] |
| | for root, _, files in os.walk(data_path): |
| | for file in files: |
| | if file.endswith((".wav", ".mp3", ".flac", ".ogg", ".m4a", ".opus")): |
| | self.data.append(os.path.join(root, file)) |
| |
|
| | self.sr = sr |
| | self.mel_fn_args = { |
| | "n_fft": spect_params['n_fft'], |
| | "win_size": spect_params['win_length'], |
| | "hop_size": spect_params['hop_length'], |
| | "num_mels": spect_params['n_mels'], |
| | "sampling_rate": sr, |
| | "fmin": spect_params['fmin'], |
| | "fmax": None if spect_params['fmax'] == "None" else spect_params['fmax'], |
| | "center": False |
| | } |
| |
|
| | assert len(self.data) != 0 |
| | while len(self.data) < batch_size: |
| | self.data += self.data |
| |
|
| | def __len__(self): |
| | return len(self.data) |
| |
|
| | def __getitem__(self, idx): |
| | idx = idx % len(self.data) |
| | wav_path = self.data[idx] |
| | try: |
| | speech, orig_sr = librosa.load(wav_path, sr=self.sr) |
| | except Exception as e: |
| | print(f"Failed to load wav file with error {e}") |
| | return self.__getitem__(random.randint(0, len(self))) |
| | if len(speech) < self.sr * duration_setting["min"] or len(speech) > self.sr * duration_setting["max"]: |
| | print(f"Audio {wav_path} is too short or too long, skipping") |
| | return self.__getitem__(random.randint(0, len(self))) |
| | if orig_sr != self.sr: |
| | speech = librosa.resample(speech, orig_sr, self.sr) |
| |
|
| | wave = torch.from_numpy(speech).float().unsqueeze(0) |
| | mel = to_mel_fn(wave, self.mel_fn_args).squeeze(0) |
| |
|
| | return wave.squeeze(0), mel |
| |
|
| |
|
| | def build_ft_dataloader(data_path, spect_params, sr, batch_size=1, num_workers=0): |
| | dataset = FT_Dataset(data_path, spect_params, sr, batch_size) |
| | dataloader = torch.utils.data.DataLoader( |
| | dataset, |
| | batch_size=batch_size, |
| | shuffle=True, |
| | num_workers=num_workers, |
| | collate_fn=collate, |
| | ) |
| | return dataloader |
| |
|
| | def collate(batch): |
| | batch_size = len(batch) |
| |
|
| | |
| | lengths = [b[1].shape[1] for b in batch] |
| | batch_indexes = np.argsort(lengths)[::-1] |
| | batch = [batch[bid] for bid in batch_indexes] |
| |
|
| | nmels = batch[0][1].size(0) |
| | max_mel_length = max([b[1].shape[1] for b in batch]) |
| | max_wave_length = max([b[0].size(0) for b in batch]) |
| |
|
| | mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10 |
| | waves = torch.zeros((batch_size, max_wave_length)).float() |
| |
|
| | mel_lengths = torch.zeros(batch_size).long() |
| | wave_lengths = torch.zeros(batch_size).long() |
| |
|
| | for bid, (wave, mel) in enumerate(batch): |
| | mel_size = mel.size(1) |
| | mels[bid, :, :mel_size] = mel |
| | waves[bid, : wave.size(0)] = wave |
| | mel_lengths[bid] = mel_size |
| | wave_lengths[bid] = wave.size(0) |
| |
|
| | return waves, mels, wave_lengths, mel_lengths |
| |
|
| | if __name__ == "__main__": |
| | data_path = "./example/reference" |
| | sr = 22050 |
| | spect_params = { |
| | "n_fft": 1024, |
| | "win_length": 1024, |
| | "hop_length": 256, |
| | "n_mels": 80, |
| | "fmin": 0, |
| | "fmax": 8000, |
| | } |
| | dataloader = build_ft_dataloader(data_path, spect_params, sr, batch_size=2, num_workers=0) |
| | for idx, batch in enumerate(dataloader): |
| | wave, mel, wave_lengths, mel_lengths = batch |
| | print(wave.shape, mel.shape) |
| | if idx == 10: |
| | break |
| |
|