Spaces:
Running on Zero
Running on Zero
| import random | |
| import pickle | |
| import torch | |
| import torchaudio | |
| import torch.nn.functional as F | |
| import torchaudio.compliance.kaldi as Kaldi | |
| from speakerlab.process.augmentation import NoiseReverbCorrupter | |
| from speakerlab.utils.fileio import load_data_csv | |
| class WavReader(object): | |
| def __init__(self, | |
| sample_rate = 16000, | |
| duration: float = 3.0, | |
| speed_pertub: bool = False, | |
| lm: bool = True, | |
| ): | |
| self.duration = duration | |
| self.sample_rate = sample_rate | |
| self.speed_pertub = speed_pertub | |
| self.lm = lm | |
| def __call__(self, wav_path): | |
| wav, sr = torchaudio.load(wav_path) | |
| assert sr == self.sample_rate | |
| wav = wav[0] | |
| if self.speed_pertub and self.lm: | |
| speeds = [1.0, 0.9, 1.1] | |
| speed_idx = random.randint(0, 2) | |
| if speed_idx > 0: | |
| wav, _ = torchaudio.sox_effects.apply_effects_tensor( | |
| wav.unsqueeze(0), self.sample_rate, [['speed', str(speeds[speed_idx])], ['rate', str(self.sample_rate)]]) | |
| else: | |
| speed_idx = 0 | |
| wav = wav.squeeze(0) | |
| data_len = wav.shape[0] | |
| chunk_len = int(self.duration * sr) | |
| if data_len >= chunk_len: | |
| start = random.randint(0, data_len - chunk_len) | |
| end = start + chunk_len | |
| wav = wav[start:end] | |
| else: | |
| wav = F.pad(wav, (0, chunk_len - data_len)) | |
| return wav, speed_idx | |
| class SpkLabelEncoder(object): | |
| def __init__(self, data_file): | |
| self.lab2ind = {} | |
| self.ind2lab = {} | |
| self.starting_index = -1 | |
| self.load_from_csv(data_file) | |
| def __call__(self, spk, speed_idx=0): | |
| spkid = self.lab2ind[spk] | |
| spkid = spkid + len(self.lab2ind) * speed_idx | |
| return spkid | |
| def load_from_csv(self, path): | |
| self.data = load_data_csv(path) | |
| for key in self.data: | |
| self.add(self.data[key]['spk']) | |
| def add(self, label): | |
| if label in self.lab2ind: | |
| return | |
| index = self._next_index() | |
| self.lab2ind[label] = index | |
| self.ind2lab[index] = label | |
| def _next_index(self): | |
| self.starting_index += 1 | |
| return self.starting_index | |
| def __len__(self): | |
| return len(self.lab2ind) | |
| def save(self, path, device=None): | |
| with open(path, 'wb') as f: | |
| pickle.dump(self.lab2ind, f) | |
| def load(self, path, device=None): | |
| self.lab2ind = {} | |
| self.ind2lab = {} | |
| with open(path, 'rb') as f: | |
| self.lab2ind = pickle.load(f) | |
| for label in self.lab2ind: | |
| self.ind2lab[self.lab2ind[label]] = label | |
| class SpkVeriAug(object): | |
| def __init__( | |
| self, | |
| aug_prob: float = 0.0, | |
| noise_file: str = None, | |
| reverb_file: str = None, | |
| ): | |
| self.aug_prob = aug_prob | |
| if aug_prob > 0: | |
| self.add_noise = NoiseReverbCorrupter( | |
| noise_prob=1.0, | |
| noise_file=noise_file, | |
| ) | |
| self.add_rir = NoiseReverbCorrupter( | |
| reverb_prob=1.0, | |
| reverb_file=reverb_file, | |
| ) | |
| self.add_rir_noise = NoiseReverbCorrupter( | |
| noise_prob=1.0, | |
| reverb_prob=1.0, | |
| noise_file=noise_file, | |
| reverb_file=reverb_file, | |
| ) | |
| self.augmentations = [self.add_noise, self.add_rir, self.add_rir_noise] | |
| def __call__(self, wav): | |
| sample_rate = 16000 | |
| if self.aug_prob > random.random(): | |
| aug = random.choice(self.augmentations) | |
| wav = aug(wav, sample_rate) | |
| return wav | |
| class FBank(object): | |
| def __init__(self, | |
| n_mels, | |
| sample_rate, | |
| mean_nor: bool = False, | |
| ): | |
| self.n_mels = n_mels | |
| self.sample_rate = sample_rate | |
| self.mean_nor = mean_nor | |
| def __call__(self, wav, dither=0): | |
| sr = 16000 | |
| assert sr==self.sample_rate | |
| if len(wav.shape) == 1: | |
| wav = wav.unsqueeze(0) | |
| # select single channel | |
| if wav.shape[0] > 1: | |
| wav = wav[0, :] | |
| wav = wav.unsqueeze(0) | |
| assert len(wav.shape) == 2 and wav.shape[0]==1 | |
| feat = Kaldi.fbank(wav, num_mel_bins=self.n_mels, | |
| sample_frequency=sr, dither=dither) | |
| # feat: [T, N] | |
| if self.mean_nor: | |
| feat = feat - feat.mean(0, keepdim=True) | |
| return feat | |