Spaces:
Running on Zero
Running on Zero
File size: 4,472 Bytes
03022ee | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 | import random
import pickle
import torch
import torchaudio
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
from local.process.augmentation import NoiseReverbCorrupter
from local.utils.fileio import load_data_csv
class WavReader(object):
    """Read an audio file and return a fixed-length mono chunk.

    Calling the reader loads the file with ``torchaudio.load``, keeps the
    first channel, optionally applies a random speed perturbation, and then
    either randomly crops or right-pads the signal to ``duration`` seconds.
    """

    def __init__(
        self,
        sample_rate=16000,
        duration: float = 3.0,
        speed_pertub: bool = False,
        lm: bool = True,
    ):
        """Store the reader configuration.

        Args:
            sample_rate: Sample rate every loaded file is asserted to have.
            duration: Length of the returned chunk, in seconds.
            speed_pertub: Enable random speed perturbation.
            lm: Second switch that must also be true for speed perturbation
                to run. NOTE(review): the meaning of the name is not visible
                in this file - confirm with the caller.
        """
        self.sample_rate = sample_rate
        self.duration = duration
        self.lm = lm
        self.speed_pertub = speed_pertub

    def __call__(self, wav_path):
        """Load ``wav_path`` and return ``(chunk, speed_idx)``.

        ``speed_idx`` is 0 when no perturbation was applied, 1 for a speed
        factor of 0.9 and 2 for a speed factor of 1.1.
        """
        signal, file_rate = torchaudio.load(wav_path)
        assert file_rate == self.sample_rate
        # Keep only the first channel.
        signal = signal[0]

        speed_idx = 0
        if self.speed_pertub and self.lm:
            speed_idx = random.randint(0, 2)
            if speed_idx > 0:
                factor = (1.0, 0.9, 1.1)[speed_idx]
                # Change the speed, then resample back to the nominal rate.
                effects = [
                    ['speed', str(factor)],
                    ['rate', str(self.sample_rate)],
                ]
                signal, _ = torchaudio.sox_effects.apply_effects_tensor(
                    signal.unsqueeze(0), self.sample_rate, effects)

        # The sox path returns a tensor with a leading channel axis; drop it.
        signal = signal.squeeze(0)
        num_samples = signal.shape[0]
        target_len = int(self.duration * file_rate)

        if num_samples < target_len:
            # Too short: pad on the right up to the target length.
            signal = F.pad(signal, (0, target_len - num_samples))
        else:
            # Long enough: take a random window of exactly target_len samples.
            offset = random.randint(0, num_samples - target_len)
            signal = signal[offset:offset + target_len]

        return signal, speed_idx
class SpkLabelEncoder(object):
    """Map speaker labels to integer ids and back.

    Ids are handed out in first-seen order starting at 0. ``lab2ind`` maps
    label -> id and ``ind2lab`` maps id -> label; ``starting_index`` holds
    the most recently assigned id (-1 while the encoder is empty).
    """

    def __init__(self, data_file):
        """Build the encoder from the CSV at ``data_file``."""
        self.lab2ind = {}
        self.ind2lab = {}
        self.starting_index = -1
        self.load_from_csv(data_file)

    def __call__(self, spk, speed_idx=0):
        """Return the id of ``spk``, shifted by ``speed_idx`` id ranges.

        Each non-zero ``speed_idx`` adds ``len(self.lab2ind)`` to the base
        id, so every speed variant of a speaker gets its own id.

        Raises:
            KeyError: if ``spk`` has not been registered.
        """
        return self.lab2ind[spk] + len(self.lab2ind) * speed_idx

    def load_from_csv(self, path):
        """Register the ``'spk'`` field of every entry loaded from ``path``."""
        self.data = load_data_csv(path)
        for key in self.data:
            self.add(self.data[key]['spk'])

    def add(self, label):
        """Register ``label`` with the next free id; no-op if already known."""
        if label in self.lab2ind:
            return
        index = self._next_index()
        self.lab2ind[label] = index
        self.ind2lab[index] = label

    def _next_index(self):
        """Advance the id counter and return the newly reserved id."""
        self.starting_index += 1
        return self.starting_index

    def __len__(self):
        """Return the number of registered labels."""
        return len(self.lab2ind)

    def save(self, path, device=None):
        """Pickle the label -> id mapping to ``path``.

        ``device`` is accepted for interface compatibility and is unused.
        """
        with open(path, 'wb') as f:
            pickle.dump(self.lab2ind, f)

    def load(self, path, device=None):
        """Replace the current mapping with the one pickled at ``path``.

        The id counter is resynchronised with the loaded mapping so that a
        later ``add`` never reuses an id that a loaded label already owns.
        The encoder is only modified after the file was read successfully.
        ``device`` is accepted for interface compatibility and is unused.
        """
        with open(path, 'rb') as f:
            # SECURITY: pickle.load can execute arbitrary code; only load
            # encoder files that come from a trusted source.
            lab2ind = pickle.load(f)
        self.lab2ind = lab2ind
        self.ind2lab = {index: label for label, index in lab2ind.items()}
        # Continue numbering after the largest loaded id (-1 when empty).
        self.starting_index = max(self.ind2lab, default=-1)
class SpkVeriAug(object):
    """Randomly corrupt a waveform with noise, reverberation, or both.

    With probability ``aug_prob`` one of three ``NoiseReverbCorrupter``
    instances (noise only, reverb only, noise + reverb) is picked uniformly
    and applied; otherwise the waveform is returned untouched.
    """

    def __init__(
        self,
        aug_prob: float = 0.0,
        noise_file: str = None,
        reverb_file: str = None,
    ):
        """Create the corrupters, but only when augmentation can happen.

        Args:
            aug_prob: Probability of augmenting a given waveform.
            noise_file: Passed to the corrupters as ``noise_file``.
            reverb_file: Passed to the corrupters as ``reverb_file``.
        """
        self.aug_prob = aug_prob
        if aug_prob > 0:
            noise_cfg = dict(noise_prob=1.0, noise_file=noise_file)
            reverb_cfg = dict(reverb_prob=1.0, reverb_file=reverb_file)
            self.add_noise = NoiseReverbCorrupter(**noise_cfg)
            self.add_rir = NoiseReverbCorrupter(**reverb_cfg)
            self.add_rir_noise = NoiseReverbCorrupter(**noise_cfg, **reverb_cfg)
            self.augmentations = [
                self.add_noise,
                self.add_rir,
                self.add_rir_noise,
            ]

    def __call__(self, wav):
        """Return ``wav``, augmented with probability ``aug_prob``."""
        # A random number is drawn on every call, even when aug_prob is 0.
        if random.random() < self.aug_prob:
            chosen = random.choice(self.augmentations)
            # The corrupter is always called with a 16 kHz sample rate.
            return chosen(wav, 16000)
        return wav
class FBank(object):
    """Compute Kaldi-style filter-bank features from a single-channel waveform."""

    def __init__(
        self,
        n_mels,
        sample_rate,
        mean_nor: bool = False,
    ):
        """Store the feature configuration.

        Args:
            n_mels: Number of mel bins passed to ``Kaldi.fbank``.
            sample_rate: Expected sample rate; ``__call__`` asserts it is 16000.
            mean_nor: If true, subtract the mean over axis 0 from the features.
        """
        self.mean_nor = mean_nor
        self.sample_rate = sample_rate
        self.n_mels = n_mels

    def __call__(self, wav, dither=0):
        """Return the filter-bank features of ``wav``.

        A 1-D input gets a leading channel axis; a multi-channel input is
        reduced to its first channel. ``dither`` is forwarded to
        ``Kaldi.fbank``.
        """
        # Only 16 kHz is supported; the configured rate must agree.
        sr = 16000
        assert sr == self.sample_rate

        if len(wav.shape) == 1:
            wav = wav.unsqueeze(0)
        if wav.shape[0] > 1:
            # Select a single channel: keep the first one, as a [1, samples] view.
            wav = wav[0, :].unsqueeze(0)
        assert len(wav.shape) == 2 and wav.shape[0] == 1

        # Per the original author's note, the result is laid out as [T, N].
        features = Kaldi.fbank(
            wav,
            num_mel_bins=self.n_mels,
            sample_frequency=sr,
            dither=dither,
        )
        if not self.mean_nor:
            return features
        return features - features.mean(0, keepdim=True)
|