# xuan3986's picture
# Upload 111 files
# 03022ee verified
import random
import pickle
import torch
import torchaudio
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
from local.process.augmentation import NoiseReverbCorrupter
from local.utils.fileio import load_data_csv
class WavReader(object):
    """Load an audio file and return a fixed-length, randomly positioned chunk.

    Args:
        sample_rate: Sample rate every file is expected to have; a file with a
            different rate fails the assertion in ``__call__``.
        duration: Length of the returned chunk, in seconds.
        speed_pertub: Enables random speed perturbation (see ``lm``).
        lm: Second gate for speed perturbation; perturbation only happens when
            both ``speed_pertub`` and ``lm`` are True.
    """

    def __init__(self,
                 sample_rate: int = 16000,
                 duration: float = 3.0,
                 speed_pertub: bool = False,
                 lm: bool = True,
                 ):
        self.sample_rate = sample_rate
        self.duration = duration
        self.speed_pertub = speed_pertub
        self.lm = lm

    def __call__(self, wav_path):
        """Read ``wav_path`` and return ``(chunk, speed_idx)``.

        Only the first channel is kept. ``speed_idx`` is the index into
        ``[1.0, 0.9, 1.1]`` of the speed factor that was applied (0 means the
        signal was left unchanged). The signal is then randomly cropped to
        ``duration`` seconds, or right-padded with zeros when shorter.
        """
        signal, rate = torchaudio.load(wav_path)
        assert rate == self.sample_rate
        signal = signal[0]  # keep only the first channel

        speed_idx = 0
        if self.speed_pertub and self.lm:
            factors = [1.0, 0.9, 1.1]
            speed_idx = random.randint(0, 2)
            if speed_idx != 0:
                # sox "speed" changes tempo+pitch; "rate" resamples back to
                # the configured sample rate.
                effects = [['speed', str(factors[speed_idx])],
                           ['rate', str(self.sample_rate)]]
                signal, _ = torchaudio.sox_effects.apply_effects_tensor(
                    signal.unsqueeze(0), self.sample_rate, effects)
        signal = signal.squeeze(0)

        target_len = int(self.duration * rate)
        n_samples = signal.shape[0]
        if n_samples < target_len:
            # too short: zero-pad on the right up to the target length
            signal = F.pad(signal, (0, target_len - n_samples))
        else:
            offset = random.randint(0, n_samples - target_len)
            signal = signal[offset:offset + target_len]
        return signal, speed_idx
class SpkLabelEncoder(object):
    """Bidirectional mapping between speaker labels and integer ids.

    Ids are handed out in first-seen order starting at 0. ``__call__`` can
    additionally offset an id by ``len(self) * speed_idx`` (presumably so each
    speed-perturbation index gets a disjoint id range -- confirm with callers).
    """

    def __init__(self, data_file):
        self.lab2ind = {}
        self.ind2lab = {}
        self.starting_index = -1  # last id handed out; -1 means none yet
        self.load_from_csv(data_file)

    def __call__(self, spk, speed_idx=0):
        """Return the id of ``spk`` offset by ``len(self) * speed_idx``.

        Raises:
            KeyError: if ``spk`` has never been added.
        """
        spkid = self.lab2ind[spk]
        return spkid + len(self.lab2ind) * speed_idx

    def load_from_csv(self, path):
        """Read ``path`` with ``load_data_csv`` and register every row's ``'spk'``."""
        self.data = load_data_csv(path)
        for key in self.data:
            self.add(self.data[key]['spk'])

    def add(self, label):
        """Assign the next free id to ``label``; no-op if it already has one."""
        if label in self.lab2ind:
            return
        index = self._next_index()
        self.lab2ind[label] = index
        self.ind2lab[index] = label

    def _next_index(self):
        """Advance the id counter and return the new id."""
        self.starting_index += 1
        return self.starting_index

    def __len__(self):
        return len(self.lab2ind)

    def save(self, path, device=None):
        """Pickle the label->id dict to ``path``. ``device`` is unused."""
        with open(path, 'wb') as f:
            pickle.dump(self.lab2ind, f)

    def load(self, path, device=None):
        """Replace the mapping with one previously written by ``save``.

        ``device`` is unused. The id counter is moved to the largest loaded
        id, so subsequent ``add`` calls never reuse an existing id.

        NOTE(review): ``pickle.load`` can execute arbitrary code embedded in
        the file -- only load files from a trusted source.
        """
        with open(path, 'rb') as f:
            lab2ind = pickle.load(f)
        # Swap in the new state only after a successful read.
        self.lab2ind = lab2ind
        self.ind2lab = {index: label for label, index in lab2ind.items()}
        self.starting_index = max(lab2ind.values(), default=-1)
class SpkVeriAug(object):
    """Randomly apply one of three noise/reverb corruptions to a waveform.

    Args:
        aug_prob: Probability that a call applies an augmentation. When it is
            not positive, no corrupters are built and calls return the input
            unchanged.
        noise_file: Forwarded to ``NoiseReverbCorrupter`` as ``noise_file``.
        reverb_file: Forwarded to ``NoiseReverbCorrupter`` as ``reverb_file``.
        sample_rate: Sample rate passed to the chosen corrupter on each call
            (defaults to 16000, the value previously hard-coded).
    """

    def __init__(
        self,
        aug_prob: float = 0.0,
        noise_file: str = None,
        reverb_file: str = None,
        sample_rate: int = 16000,
    ):
        self.aug_prob = aug_prob
        self.sample_rate = sample_rate
        # Always defined so the attribute exists even when augmentation is off.
        self.augmentations = []
        if aug_prob > 0:
            # noise only
            self.add_noise = NoiseReverbCorrupter(
                noise_prob=1.0,
                noise_file=noise_file,
            )
            # reverb (RIR) only
            self.add_rir = NoiseReverbCorrupter(
                reverb_prob=1.0,
                reverb_file=reverb_file,
            )
            # reverb and noise together
            self.add_rir_noise = NoiseReverbCorrupter(
                noise_prob=1.0,
                reverb_prob=1.0,
                noise_file=noise_file,
                reverb_file=reverb_file,
            )
            self.augmentations = [self.add_noise, self.add_rir, self.add_rir_noise]

    def __call__(self, wav):
        """With probability ``aug_prob``, corrupt ``wav`` with a random augmentation."""
        if self.aug_prob > random.random():
            aug = random.choice(self.augmentations)
            wav = aug(wav, self.sample_rate)
        return wav
class FBank(object):
    """Compute Kaldi-compatible filterbank features with ``Kaldi.fbank``.

    Args:
        n_mels: Number of mel bins (``num_mel_bins``).
        sample_rate: Sample rate of the incoming waveform; used as
            ``sample_frequency``.
        mean_nor: If True, subtract the per-bin mean over time from the
            features.
    """

    def __init__(self,
                 n_mels,
                 sample_rate,
                 mean_nor: bool = False,
                 ):
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.mean_nor = mean_nor

    def __call__(self, wav, dither=0):
        """Return a ``[T, n_mels]`` feature tensor for ``wav``.

        ``wav`` may be 1-D ``(samples,)`` or 2-D ``(channels, samples)``; only
        the first channel is used. Higher-rank input fails the shape assertion.
        """
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)
        # select a single channel, keeping the leading channel axis
        if wav.shape[0] > 1:
            wav = wav[:1]
        assert wav.dim() == 2 and wav.shape[0] == 1
        # Use the configured rate rather than a hard-coded 16 kHz constant.
        feat = Kaldi.fbank(wav, num_mel_bins=self.n_mels,
                           sample_frequency=self.sample_rate, dither=dither)
        # feat: [T, N]
        if self.mean_nor:
            feat = feat - feat.mean(0, keepdim=True)
        return feat