import os
import pickle
from collections import defaultdict

import librosa
import numpy as np
import torch
from torch.utils import data
from tqdm import tqdm
from transformers import Wav2Vec2Processor


class Dataset(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, data, subjects_dict, data_type="train", read_audio=False):
        self.data = data
        self.len = len(self.data)
        self.subjects_dict = subjects_dict
        self.data_type = data_type
        # One-hot identity conditioning: one row of the identity matrix per
        # training subject.
        self.one_hot_labels = np.eye(len(subjects_dict["train"]))
        self.read_audio = read_audio

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        file_name = self.data[index]["name"]
        audio = self.data[index]["audio"]
        vertice = self.data[index]["vertice"]
        template = self.data[index]["template"]

        # The one-hot condition indexes into the *training* subject list for
        # every split (train, val and test alike).
        if len(self.one_hot_labels) == 1:
            one_hot = self.one_hot_labels[0]
        else:
            subject = file_name.split("_")[0]
            one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject.capitalize())]

        if self.read_audio:
            return torch.FloatTensor(audio), torch.FloatTensor(vertice), torch.FloatTensor(template), torch.FloatTensor(one_hot), file_name
        else:
            return torch.FloatTensor(vertice), torch.FloatTensor(template), torch.FloatTensor(one_hot), file_name

    def __len__(self):
        return self.len
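
# Item structure (a sketch, inferred from read_data below): with
# read_audio=True each item is (audio, vertice, template, one_hot, file_name);
# with read_audio=False the audio tensor is omitted. audio is the raw 16 kHz
# waveform after Wav2Vec2Processor, vertice is (frames, num_verts * 3),
# template is (num_verts * 3,), and one_hot has one entry per training subject.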
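
# Expected layout under args.data_root (inferred from read_data):
#   <args.wav_path>/       *.wav audio clips
#   <args.vertices_path>/  *.npz per-clip vertex animations
#   <args.template_file>   pickled {subject_id: neutral-face template} dict
#   train.txt / test.txt   one wav file name per line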
def read_data(args, test_config=False):
    print("Loading data...")
    data = defaultdict(dict)
    train_data = []
    valid_data = []
    test_data = []

    audio_path = os.path.join(args.data_root, args.wav_path)
    vertices_path = os.path.join(args.data_root, args.vertices_path)
    if args.read_audio:
        processor = Wav2Vec2Processor.from_pretrained(args.wav2vec2model_path)

    template_file = os.path.join(args.data_root, args.template_file)
    with open(template_file, 'rb') as fin:
        templates = pickle.load(fin, encoding='latin1')

    # Read the train/test file lists (one wav file name per line).
    with open(os.path.join(args.data_root, "train.txt"), "r") as train_txt:
        train_list = [line.strip() for line in train_txt]
    with open(os.path.join(args.data_root, "test.txt"), "r") as test_txt:
        test_list = [line.strip() for line in test_txt]

    for r, ds, fs in os.walk(audio_path):
        for f in tqdm(fs):
            # In test-only mode, skip any clip not listed in test.txt.
            if test_config and f not in test_list:
                continue

            if f.endswith("wav"):
                key = f.replace("wav", "npy")
                if args.read_audio:
                    wav_path = os.path.join(r, f)
                    speech_array, sampling_rate = librosa.load(wav_path, sr=16000)
                    data[key]["audio"] = np.squeeze(processor(speech_array, sampling_rate=16000).input_values)
                else:
                    data[key]["audio"] = None

                # Subject id is everything before the last "_" in the file name;
                # it keys the per-subject neutral template.
                subject_id = "_".join(key.split("_")[:-1])
                temp = templates[subject_id]

                data[key]["name"] = f
                data[key]["template"] = temp.reshape((-1))

                vertice_path = os.path.join(vertices_path, f.replace("wav", "npz"))
                if not os.path.exists(vertice_path):
                    del data[key]
                elif args.dataset == "vocaset":
                    # Take every other frame to halve the sequence length.
                    data[key]["vertice"] = np.load(vertice_path, allow_pickle=True)[::2, :]
                elif args.dataset == "BIWI":
                    data[key]["vertice"] = np.load(vertice_path, allow_pickle=True)
                elif args.dataset == "multi":
                    flame_param = np.load(vertice_path, allow_pickle=True)
                    data[key]["vertice"] = flame_param["verts"].reshape((flame_param["verts"].shape[0], -1))

    subjects_dict = {}
    subjects_dict["train"] = args.train_subjects.split(" ")
    subjects_dict["val"] = args.val_subjects.split(" ")
    subjects_dict["test"] = args.test_subjects.split(" ")

    # Files listed in train.txt are split 90/10 into train/val; files listed
    # in test.txt form the test set.
    train_cnt = 0
    for k, v in data.items():
        k_wav = k.replace("npy", "wav")
        if k_wav in train_list:
            if train_cnt < int(len(train_list) * 0.9):
                train_data.append(v)
            else:
                valid_data.append(v)
            train_cnt += 1
        elif k_wav in test_list:
            test_data.append(v)

    print('Loaded data: Train-{}, Val-{}, Test-{}'.format(len(train_data), len(valid_data), len(test_data)))
    return train_data, valid_data, test_data, subjects_dict


def get_dataloaders(args, test_config=False):
    dataset = {}
    train_data, valid_data, test_data, subjects_dict = read_data(args, test_config)

    if not test_config:
        train_data = Dataset(train_data, subjects_dict, "train", args.read_audio)
        dataset["train"] = data.DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True,
                                           num_workers=args.workers)
        valid_data = Dataset(valid_data, subjects_dict, "val", args.read_audio)
        dataset["valid"] = data.DataLoader(dataset=valid_data, batch_size=1, shuffle=False, num_workers=args.workers)
    # The test loader is built in both modes; in test-only mode it is the
    # only loader returned.
    test_data = Dataset(test_data, subjects_dict, "test", args.read_audio)
    dataset["test"] = data.DataLoader(dataset=test_data, batch_size=1, shuffle=True, num_workers=args.workers)
    return dataset


if __name__ == "__main__":
    # Smoke test. get_dataloaders() requires a parsed-args namespace; the
    # sketch below builds one by hand.
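    # Minimal sketch: the attribute names match what read_data consumes, but
    # every path and subject name here is a hypothetical placeholder.
    from argparse import Namespace

    example_args = Namespace(
        data_root="data",               # hypothetical dataset root
        wav_path="wav",                 # subfolder with *.wav clips
        vertices_path="vertices_npz",   # subfolder with *.npz vertices
        template_file="templates.pkl",  # pickled {subject_id: template} dict
        wav2vec2model_path="facebook/wav2vec2-base-960h",  # assumed checkpoint
        dataset="vocaset",
        read_audio=True,
        batch_size=1,
        workers=0,
        train_subjects="SubjectA",      # placeholder subject lists
        val_subjects="SubjectA",
        test_subjects="SubjectA",
    )
    loaders = get_dataloaders(example_args)
    print({split: len(dl.dataset) for split, dl in loaders.items()})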