| | |
| | |
| | ''' |
| | @Project :Waveformer-main |
| | @File :dataset_online.py |
| | @IDE :PyCharm |
| | @Author :Aisaka/Hao Ma @SDU |
| | @Date :2023/11/1 下午6:47 |
| | ''' |
| | import os |
| | import random |
| |
|
| | import torch |
| | import torchaudio |
| | import torchaudio.transforms as AT |
| | import csv |
| | import json |
| | import numpy as np |
| | import librosa |
| |
|
| |
|
def labels2caption(labels):
    """Turn a sequence of label strings into a natural-language caption.

    Exactly one label produces "The sound of <label>"; any other count
    produces "The sounds of <l1>, <l2>, ..." (an empty sequence therefore
    yields just "The sounds of ").
    """
    if len(labels) != 1:
        head = "The sounds of "
    else:
        head = "The sound of "
    return head + ', '.join(labels)
| |
|
| |
|
class CLAPSepDataSet(torch.utils.data.Dataset):
    """On-the-fly mixing dataset for query-conditioned source separation.

    Each item mixes a target clip with a "noise" clip (SNR 0 via
    ``torchaudio.functional.add_noise``) and returns text captions plus
    augmented audio queries for both the target and the noise. In train mode
    the noise clip is drawn per item with a label set disjoint from the
    target's; in val mode the pairing is drawn once at construction.

    The metadata CSV has one clip per row: ``path, label_1, label_2, ...``.
    """

    # Sample rate used by the audio-query branch (speed-perturbation
    # augmentation and ``pad_or_trim`` length). Presumably the CLAP audio
    # encoder's input rate -- TODO confirm it matches ``resample_rate``.
    QUERY_SR = 48000
    # Rejection-sampling attempts in ``_sample_disjoint`` before falling back
    # to an exhaustive scan (prevents an endless loop when no clip qualifies).
    MAX_SAMPLING_TRIES = 1000

    def __init__(self, data_list, dset='', silence_rate=0.05, chunk_dur=10, sr=None, resample_rate=None):
        """
        Args:
            data_list: path to the metadata CSV (``path, label, label, ...``).
            dset: ``'train'`` or ``'val'``.
            silence_rate: probability (train only) of turning an item into a
                silent-target example, see ``__getitem__``.
            chunk_dur: clip duration in seconds; audio is trimmed/padded to it.
            sr: rate passed to ``librosa.load`` when reading clips.
            resample_rate: rate produced by ``self.resampler`` for the
                resampled mixture and audio queries; ``None`` means no
                resampling (a pass-through is used).
        """
        assert dset in ['train', 'val'], \
            "`dset` must be one of ['train', 'val']"
        self.dset = dset
        self.silence_rate = silence_rate
        self.chunk_dur = chunk_dur
        self.data_meta = dict()  # clip path -> list of its labels
        self.text_dict = dict()  # ', '-joined label string -> list of clip paths
        with open(data_list, 'r', encoding='utf-8') as d:
            reader = csv.reader(d, skipinitialspace=True)
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines.
                    continue
                assert os.path.exists(row[0]), row[0]
                self.data_meta[row[0]] = row[1:]
                label = ', '.join(row[1:])
                self.text_dict.setdefault(label, []).append(row[0])

        self.augmentation = torchaudio.transforms.SpeedPerturbation(self.QUERY_SR, [0.9, 1.1])

        self.data_names = list(self.data_meta.keys())
        if dset == 'val':
            # Noise pairing is fixed for the lifetime of this object.
            # NOTE(review): this only guarantees a different label *string*,
            # not a disjoint label set as in training -- confirm intended.
            self.noise_names = [
                self.choose_other_samples(', '.join(self.data_meta[name]), 1)[0]
                for name in self.data_names
            ]

        self.sr = sr
        if resample_rate is not None:
            self.resampler = AT.Resample(sr, resample_rate)
            self.resample_rate = resample_rate
        else:
            # __getitem__ always calls self.resampler, so provide a
            # pass-through; nn.Identity (unlike a lambda) stays picklable for
            # multi-worker DataLoaders.
            self.resampler = torch.nn.Identity()
            self.resample_rate = sr

    def __len__(self):
        return len(self.data_names)

    def choose_other_samples(self, target_text, num):
        """Return ``num`` clip paths, one from each of ``num`` distinct label
        strings other than ``target_text`` (chosen uniformly at random).

        Raises ValueError (from ``random.sample``) if fewer than ``num`` other
        label strings exist.
        """
        candidates = [text for text in self.text_dict if text != target_text]
        chosen_text = random.sample(candidates, num)
        return [random.choice(self.text_dict[text]) for text in chosen_text]

    def _sample_disjoint(self, start, excluded):
        """Return a clip path whose labels share nothing with ``excluded``.

        ``start`` is tested first, then uniform random draws from the dataset.
        After ``MAX_SAMPLING_TRIES`` failed checks a uniform choice among all
        qualifying clips is made; RuntimeError is raised if there are none.
        """
        name = start
        for _ in range(self.MAX_SAMPLING_TRIES):
            if not set(self.data_meta[name]) & excluded:
                return name
            name = random.choice(self.data_names)
        candidates = [n for n in self.data_names if not set(self.data_meta[n]) & excluded]
        if not candidates:
            raise RuntimeError(f"no clip has labels disjoint from {sorted(excluded)}")
        return random.choice(candidates)

    def load_wav(self, path):
        """Load ``path`` at ``self.sr`` and trim or zero-pad it to exactly
        ``int(sr * chunk_dur)`` samples; returns a numpy array."""
        max_length = int(self.sr * self.chunk_dur)
        wav = librosa.load(path, sr=self.sr)[0]
        wav = wav[:max_length]
        if len(wav) < max_length:
            wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav

    def __getitem__(self, idx):
        """Build one example.

        Returns a 7-tuple:
            mixed: 1-D mixture at ``sr``.
            mixed_resample: ``mixed`` passed through ``self.resampler``.
            tgt_cap: caption of the sound to extract.
            neg_cap: caption of the sound(s) to suppress.
            tgt: 1-D ground-truth target at ``sr`` (all zeros for a
                silent-target example).
            pos_sample: augmented audio query for the target, length
                ``QUERY_SR * chunk_dur``.
            neg_sample: augmented audio query for the suppressed sound(s),
                same length.
        """
        tgt_name = self.data_names[idx]
        if self.dset == 'train':
            noise_name = self._sample_disjoint(tgt_name, set(self.data_meta[tgt_name]))
        else:
            noise_name = self.noise_names[idx]

        snr = torch.zeros((1,))
        tgt = torch.tensor(self.load_wav(tgt_name)).unsqueeze(0)
        noise = torch.tensor(self.load_wav(noise_name)).unsqueeze(0)
        mixed = torchaudio.functional.add_noise(tgt, noise, snr=snr)
        # Presumably catches silent clips that make the SNR scaling blow up -- confirm.
        assert not torch.isnan(mixed).any(), f"tgt: {tgt_name}, noise: {noise_name}"
        pos_sample, _ = self.augmentation(self.resampler(tgt.squeeze()))
        neg_sample, _ = self.augmentation(self.resampler(noise.squeeze()))

        # Peak-normalise; tgt gets the same factor so it stays consistent with mixed.
        max_value = torch.max(torch.abs(mixed))
        if max_value > 1:
            tgt *= 0.9 / max_value
            mixed *= 0.9 / max_value

        tgt = tgt.squeeze()
        mixed = mixed.squeeze()
        tgt_cap = labels2caption(self.data_meta[tgt_name])
        neg_cap = labels2caption(self.data_meta[noise_name])
        mixed_resample = self.resampler(mixed)

        if self.dset == 'train' and random.random() < self.silence_rate:
            # Silent-target example: query a sound whose labels appear in
            # neither clip (target becomes all zeros) and suppress everything
            # that is actually in the mixture.
            other_name = self._sample_disjoint(
                tgt_name, set(self.data_meta[tgt_name]) | set(self.data_meta[noise_name]))
            tgt = torch.zeros_like(mixed)
            neg_cap = labels2caption(self.data_meta[tgt_name] + self.data_meta[noise_name])
            tgt_cap = labels2caption(self.data_meta[other_name])
            pos_sample, _ = self.augmentation(self.resampler(torch.tensor(self.load_wav(other_name))))
            neg_sample, _ = self.augmentation(mixed_resample)

        return mixed, mixed_resample, tgt_cap, neg_cap, tgt, self.pad_or_trim(pos_sample), self.pad_or_trim(neg_sample)

    def pad_or_trim(self, wav_in):
        """Zero-pad or trim a 1-D waveform to ``QUERY_SR * chunk_dur`` samples
        and rescale its peak to 0.9 if it exceeds 1.

        Returns a tensor; ``wav_in`` itself is never modified.
        """
        target_len = int(self.QUERY_SR * self.chunk_dur)
        if wav_in.size(0) < target_len:
            wav_in = torch.nn.functional.pad(wav_in, (0, target_len - wav_in.size(0)))
        elif wav_in.size(0) > target_len:
            wav_in = wav_in[:target_len]
        max_value = torch.max(torch.abs(wav_in))
        if max_value > 1:
            # Out-of-place: after trimming, wav_in is a view of the caller's tensor.
            wav_in = wav_in * (0.9 / max_value)
        return wav_in
| |
|
| |
|
class CLAPSepDataEngineDataSet(torch.utils.data.Dataset):
    """On-the-fly mixing dataset that can also draw examples from a
    "data engine" JSON.

    Base behaviour: each item mixes a target clip with a noise clip whose
    label set is disjoint from the target's (SNR 0 via
    ``torchaudio.functional.add_noise``) and returns captions plus augmented
    audio queries. In train mode, when the target clip's key (file name
    without extension) has a non-empty entry in the data-engine JSON, the
    item is rebuilt from that entry with probability 0.5. The clip itself
    becomes the mixture, one entry becomes the target, and the remaining
    entries are averaged into the negative.

    The metadata CSV has one clip per row: ``path, label_1, label_2, ...``.
    Each data-engine entry is a list of items where ``item[0]`` is an audio
    path and ``item[1]`` a caption/label string (presumably sources contained
    in that clip -- confirm against the JSON producer).
    """

    # Sample rate used by the audio-query branch (speed-perturbation
    # augmentation and ``pad_or_trim`` length). Presumably the CLAP audio
    # encoder's input rate -- TODO confirm it matches ``resample_rate``.
    QUERY_SR = 48000
    # Rejection-sampling attempts in ``_sample_disjoint`` before falling back
    # to an exhaustive scan (prevents an endless loop when no clip qualifies).
    MAX_SAMPLING_TRIES = 1000

    def __init__(self, data_list, dset='', data_engine_json='', silence_rate=0.05, chunk_dur=10, sr=None, resample_rate=None):
        """
        Args:
            data_list: path to the metadata CSV (``path, label, label, ...``).
            dset: ``'train'`` or ``'val'``.
            data_engine_json: optional path to the data-engine JSON; ignored
                if empty or missing.
            silence_rate: probability (train only) of turning an item into a
                silent-target example, see ``__getitem__``.
            chunk_dur: clip duration in seconds; audio is trimmed/padded to it.
            sr: rate passed to ``librosa.load`` when reading clips.
            resample_rate: rate produced by ``self.resampler``; ``None`` means
                no resampling (a pass-through is used).
        """
        assert dset in ['train', 'val'], \
            "`dset` must be one of ['train', 'val']"
        self.dset = dset
        self.silence_rate = silence_rate
        self.chunk_dur = chunk_dur
        self.data_meta = dict()  # clip path -> list of its labels
        with open(data_list, 'r', encoding='utf-8') as d:
            reader = csv.reader(d, skipinitialspace=True)
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines.
                    continue
                assert os.path.exists(row[0]), row[0]
                self.data_meta[row[0]] = row[1:]

        self.augmentation = torchaudio.transforms.SpeedPerturbation(self.QUERY_SR, [0.9, 1.1])

        self.data_names = list(self.data_meta.keys())
        if dset == 'val':
            # Noise pairing is fixed for the lifetime of this object.
            self.noise_names = [
                self._sample_disjoint(name, set(self.data_meta[name]))
                for name in self.data_names
            ]

        self.data_engine_dict = {}
        if data_engine_json and os.path.exists(data_engine_json):
            with open(data_engine_json, 'r', encoding='utf-8') as f:
                self.data_engine_dict = json.load(f)

        self.sr = sr
        if resample_rate is not None:
            self.resampler = AT.Resample(sr, resample_rate)
            self.resample_rate = resample_rate
        else:
            # __getitem__ always calls self.resampler, so provide a
            # pass-through; nn.Identity (unlike a lambda) stays picklable for
            # multi-worker DataLoaders.
            self.resampler = torch.nn.Identity()
            self.resample_rate = sr

    def __len__(self):
        return len(self.data_names)

    def _sample_disjoint(self, start, excluded):
        """Return a clip path whose labels share nothing with ``excluded``.

        ``start`` is tested first, then uniform random draws from the dataset.
        After ``MAX_SAMPLING_TRIES`` failed checks a uniform choice among all
        qualifying clips is made; RuntimeError is raised if there are none.
        """
        name = start
        for _ in range(self.MAX_SAMPLING_TRIES):
            if not set(self.data_meta[name]) & excluded:
                return name
            name = random.choice(self.data_names)
        candidates = [n for n in self.data_names if not set(self.data_meta[n]) & excluded]
        if not candidates:
            raise RuntimeError(f"no clip has labels disjoint from {sorted(excluded)}")
        return random.choice(candidates)

    @staticmethod
    def _video_key(path):
        """Data-engine key for a clip: its file name without directory or extension."""
        return os.path.splitext(os.path.basename(path))[0]

    @staticmethod
    def _split_engine_items(items):
        """Pick one entry uniformly at random as the target.

        Returns ``(target_item, remaining_items)`` without mutating ``items``;
        the stored data-engine entry must stay intact across epochs.
        """
        tgt_idx = random.randrange(len(items))
        return items[tgt_idx], items[:tgt_idx] + items[tgt_idx + 1:]

    def load_wav(self, path):
        """Load ``path`` at ``self.sr`` and trim or zero-pad it to exactly
        ``int(sr * chunk_dur)`` samples; returns a numpy array."""
        max_length = int(self.sr * self.chunk_dur)
        wav = librosa.load(path, sr=self.sr)[0]
        wav = wav[:max_length]
        if len(wav) < max_length:
            wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav

    def __getitem__(self, idx):
        """Build one example.

        Returns a 7-tuple:
            mixed: 1-D mixture at ``sr``.
            mixed_resample: ``mixed`` passed through ``self.resampler``.
            tgt_cap: caption of the sound to extract.
            neg_cap: caption of the sound(s) to suppress.
            tgt: 1-D ground-truth target at ``sr`` (all zeros for a
                silent-target example).
            pos_sample: augmented audio query for the target, length
                ``QUERY_SR * chunk_dur``.
            neg_sample: augmented audio query for the suppressed sound(s),
                same length.
        """
        tgt_name = self.data_names[idx]
        if self.dset == 'train':
            noise_name = self._sample_disjoint(tgt_name, set(self.data_meta[tgt_name]))
        else:
            noise_name = self.noise_names[idx]

        snr = torch.zeros((1,))
        tgt = torch.tensor(self.load_wav(tgt_name)).unsqueeze(0)
        noise = torch.tensor(self.load_wav(noise_name)).unsqueeze(0)
        mixed = torchaudio.functional.add_noise(tgt, noise, snr=snr)
        # Presumably catches silent clips that make the SNR scaling blow up -- confirm.
        assert not torch.isnan(mixed).any(), f"tgt: {tgt_name}, noise: {noise_name}"

        pos_sample, _ = self.augmentation(self.resampler(tgt.squeeze()))
        noise = noise.squeeze()

        # Peak-normalise; tgt gets the same factor so it stays consistent with mixed.
        max_value = torch.max(torch.abs(mixed))
        if max_value > 1:
            tgt *= 0.9 / max_value
            mixed *= 0.9 / max_value

        tgt = tgt.squeeze()
        mixed = mixed.squeeze()
        tgt_cap = labels2caption(self.data_meta[tgt_name])
        neg_cap = labels2caption(self.data_meta[noise_name])
        mixed_resample = self.resampler(mixed)

        engine_items = self.data_engine_dict.get(self._video_key(tgt_name))
        if self.dset == 'train' and engine_items and random.random() > 0.5:
            # Data-engine example: the target clip alone is the mixture.
            mixed = tgt
            mixed_resample = self.resampler(mixed)
            tgt_item, noise_items = self._split_engine_items(engine_items)
            tgt = torch.tensor(self.load_wav(tgt_item[0]))
            max_value = torch.max(torch.abs(tgt))
            if max_value > 1:
                tgt *= 0.9 / max_value
            # NOTE(review): used verbatim, whereas neg_cap below is wrapped by
            # labels2caption -- confirm the JSON already stores full captions.
            tgt_cap = tgt_item[1]
            # The audio query must describe the new target, not the mixture.
            pos_sample, _ = self.augmentation(self.resampler(tgt))
            if noise_items:
                noises = [torch.tensor(self.load_wav(item[0])) for item in noise_items]
                noise = torch.mean(torch.stack(noises, dim=0), dim=0)
                neg_cap = labels2caption([item[1] for item in noise_items])
            # Otherwise noise/neg_cap keep the disjoint noise clip, which is
            # not part of this mixture.
        elif self.dset == 'train' and random.random() < self.silence_rate:
            # Silent-target example: query a sound whose labels appear in
            # neither clip (target becomes all zeros) and suppress everything
            # that is actually in the mixture.
            other_name = self._sample_disjoint(
                tgt_name, set(self.data_meta[tgt_name]) | set(self.data_meta[noise_name]))
            tgt = torch.zeros_like(mixed)
            neg_cap = labels2caption(self.data_meta[tgt_name] + self.data_meta[noise_name])
            tgt_cap = labels2caption(self.data_meta[other_name])
            pos_sample, _ = self.augmentation(self.resampler(torch.tensor(self.load_wav(other_name))))
            noise = mixed

        neg_sample, _ = self.augmentation(self.resampler(noise))

        return mixed, mixed_resample, tgt_cap, neg_cap, tgt, self.pad_or_trim(pos_sample), self.pad_or_trim(neg_sample)

    def pad_or_trim(self, wav_in):
        """Zero-pad or trim a 1-D waveform to ``QUERY_SR * chunk_dur`` samples
        and rescale its peak to 0.9 if it exceeds 1.

        Returns a tensor; ``wav_in`` itself is never modified.
        """
        target_len = int(self.QUERY_SR * self.chunk_dur)
        if wav_in.size(0) < target_len:
            wav_in = torch.nn.functional.pad(wav_in, (0, target_len - wav_in.size(0)))
        elif wav_in.size(0) > target_len:
            wav_in = wav_in[:target_len]
        max_value = torch.max(torch.abs(wav_in))
        if max_value > 1:
            # Out-of-place: after trimming, wav_in is a view of the caller's tensor.
            wav_in = wav_in * (0.9 / max_value)
        return wav_in
| |
|
| |
|