| | import os |
| | import sys |
| | import time |
| | import librosa |
| | import yaml |
| | import joblib |
| | import argparse |
| |
|
| | import soundfile as sf |
| | import numpy as np |
| |
|
| | from pathlib import Path |
| | from collections import defaultdict |
| | from typing import Optional |
| | from tqdm import tqdm |
| |
|
| | sys.path.append(os.path.dirname(__file__)) |
| | sys.path.append(os.path.dirname(os.path.dirname(__file__))) |
| |
|
| | |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.distributed as dist |
| | from torch.nn.parallel import DistributedDataParallel |
| |
|
| | |
| | from nnet.WavLM import WavLM, WavLMConfig |
| |
|
| | |
| | from vq.codec_encoder import CodecEncoder_Transformer |
| | from vq.codec_decoder_vocos import CodecDecoderVocos |
| | from vq.module import SemanticEncoder |
| | from transformers import AutoFeatureExtractor, Wav2Vec2BertModel |
| | from collections import OrderedDict |
| |
|
| | |
| | from loader.datareader import DataReader |
| | from loader.datareader_aec import DataReaderAEC |
| | from loader.datareader_tse import DataReaderTSE |
| |
|
| | |
| | from nnet.llase import LLM_AR as model |
| |
|
class Encodec():
    '''
    Load Xcodec2: wraps the codec encoder/decoder, the frozen w2v-BERT
    semantic model, and the two projection layers, all restored from a
    single Lightning checkpoint and kept in eval mode on `self.device`.
    '''
    def __init__(self, device="cpu") -> None:
        self.device = device
        # BUG FIX: the original line ended with a trailing comma, which made
        # `ckpt` a 1-tuple and broke the torch.load() call below.
        ckpt = "./ckpt/codec_ckpt/epoch=4-step=1400000.ckpt"

        ckpt = torch.load(ckpt, map_location='cpu')
        state_dict = ckpt['state_dict']

        # Split the flat Lightning state dict into one state dict per
        # sub-module, keyed by its checkpoint prefix.
        buckets = {
            'CodecEnc.': OrderedDict(),
            'generator.': OrderedDict(),
            'fc_post_a.': OrderedDict(),
            'SemanticEncoder_module.': OrderedDict(),
            'fc_prior.': OrderedDict(),
        }
        for key, value in state_dict.items():
            for prefix, bucket in buckets.items():
                if key.startswith(prefix):
                    bucket[key[len(prefix):]] = value
                    break

        filtered_state_dict_codec = buckets['CodecEnc.']
        filtered_state_dict_gen = buckets['generator.']
        filtered_state_dict_fc_post_a = buckets['fc_post_a.']
        filtered_state_dict_semantic_encoder = buckets['SemanticEncoder_module.']
        filtered_state_dict_fc_prior = buckets['fc_prior.']

        # Frozen w2v-BERT backbone used to produce semantic features.
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(
            "./ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0",
            output_hidden_states=True)
        self.semantic_model = self.semantic_model.eval().to(self.device)

        self.SemanticEncoder_module = SemanticEncoder(1024, 1024, 1024)
        self.SemanticEncoder_module.load_state_dict(filtered_state_dict_semantic_encoder)
        self.SemanticEncoder_module = self.SemanticEncoder_module.eval().to(self.device)

        # Acoustic encoder (wav -> embedding).
        self.encoder = CodecEncoder_Transformer()
        self.encoder.load_state_dict(filtered_state_dict_codec)
        self.encoder = self.encoder.eval().to(self.device)

        # Vocos decoder (tokens/embedding -> wav), also hosts the quantizer.
        self.decoder = CodecDecoderVocos()
        self.decoder.load_state_dict(filtered_state_dict_gen)
        self.decoder = self.decoder.eval().to(self.device)

        # Post-quantizer projection used in token2wav.
        self.fc_post_a = nn.Linear(2048, 1024)
        self.fc_post_a.load_state_dict(filtered_state_dict_fc_post_a)
        self.fc_post_a = self.fc_post_a.eval().to(self.device)

        # Pre-quantizer projection used in emb2token.
        self.fc_prior = nn.Linear(2048, 2048)
        self.fc_prior.load_state_dict(filtered_state_dict_fc_prior)
        self.fc_prior = self.fc_prior.eval().to(self.device)

        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            "./ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0")

    def get_feat(self, wav_batch, pad=None):
        """Run the w2v-BERT feature extractor over a single wav (1-D tensor)
        or a batch of wavs (2-D tensor), optionally sample-padding first.

        NOTE: the batched path moves the features to `self.device`; the
        single-wav path returns them on CPU (callers `.to(device)` the
        result themselves — see get_embedding).
        """
        # ROBUSTNESS FIX: the original passed pad=None straight into F.pad
        # (which raises); treat "no pad" as zero padding instead.
        if pad is None:
            pad = (0, 0)

        if len(wav_batch.shape) != 2:
            return self.feature_extractor(F.pad(wav_batch, pad), sampling_rate=16000, return_tensors="pt").data['input_features']

        padded_wavs = torch.stack([F.pad(wav, pad) for wav in wav_batch])
        batch_feats = []
        for wav in padded_wavs:
            feat = self.feature_extractor(
                wav,
                sampling_rate=16000,
                return_tensors="pt"
            ).data['input_features']
            batch_feats.append(feat)
        feat_batch = torch.concat(batch_feats, dim=0).to(self.device)
        return feat_batch

    def get_embedding(self, wav_cpu):
        """Encode a waveform into the concatenated [semantic ; acoustic]
        embedding expected by emb2token."""
        wav_cpu = wav_cpu.cpu()
        feat = self.get_feat(wav_cpu, pad=(160, 160))
        feat = feat.to(self.device)

        if len(wav_cpu.shape) == 1:
            wav = wav_cpu.unsqueeze(0).to(self.device)
        else:
            wav = wav_cpu.to(self.device)

        # Pad so the length is a multiple of 200 samples.
        # NOTE(review): when the length is already a multiple of 200 this
        # pads a full extra 200 samples — presumably intentional; confirm.
        wav = torch.nn.functional.pad(wav, (0, (200 - (wav.shape[1] % 200))))
        with torch.no_grad():
            vq_emb = self.encoder(wav.unsqueeze(1))
            vq_emb = vq_emb.transpose(1, 2)

            # If padded features misalign with the encoder frames, redo the
            # extraction without padding.
            if vq_emb.shape[2] != feat.shape[1]:
                feat = self.get_feat(wav_cpu)
                feat = feat.to(self.device)

            # Layer-16 hidden states serve as the semantic target.
            semantic_target = self.semantic_model(feat[:, :, :])
            semantic_target = semantic_target.hidden_states[16]
            semantic_target = semantic_target.transpose(1, 2)
            semantic_target = self.SemanticEncoder_module(semantic_target)

            vq_emb = torch.cat([semantic_target, vq_emb], dim=1)

        return vq_emb

    def emb2token(self, emb):
        """Quantize an embedding into discrete codec token indices."""
        # BUG FIX: Tensor.to() is not in-place; the original discarded it.
        emb = emb.to(self.device)
        emb = self.fc_prior(emb.transpose(1, 2)).transpose(1, 2)
        _, vq_code, _ = self.decoder(emb, vq=True)
        return vq_code

    def token2wav(self, vq_code):
        """Decode discrete codec token indices back into a waveform."""
        # BUG FIX: Tensor.to() is not in-place; the original discarded it.
        vq_code = vq_code.to(self.device)
        vq_post_emb = self.decoder.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
        vq_post_emb = vq_post_emb.transpose(1, 2)
        vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)
        recon = self.decoder(vq_post_emb.transpose(1, 2), vq=False)[0].squeeze()

        return recon
| |
|
class WavLM_feat(object):
    '''
    Load WavLM and expose it as a callable waveform -> feature extractor.
    '''
    def __init__(self, device):
        self.wavlm = self._reload_wavLM_large(device=device)

    def __call__(self, wav):
        """Extract layer-6 WavLM features for `wav` and flatten them to
        shape (frames, feat_dim) — batch and time axes are merged."""
        T = wav.shape[-1]
        wav = wav.reshape(-1, T)
        with torch.no_grad():
            feat = self.wavlm.extract_features(wav, output_layer=6, ret_layer_results=False)[0]
            # Merge (batch, time) into a single frame axis.
            D = feat.shape[-1]
            feat = torch.reshape(feat, (-1, D))

        return feat

    def _reload_wavLM_large(self, path="/home/bykang/WavLM-Large.pt", device: Optional[torch.device] = None):
        """Restore a frozen WavLM-Large from `path` and (optionally) move
        it to `device` with gradients disabled.

        NOTE(review): the default path is a user-specific absolute path —
        consider moving it into the YAML config.
        """
        cpt = torch.load(path, map_location="cpu")
        cfg = WavLMConfig(cpt['cfg'])
        wavLM = WavLM(cfg)
        wavLM.load_state_dict(cpt['model'])
        wavLM.eval()
        if device is not None:
            wavLM = wavLM.to(device)
        # Freeze: this model is only ever used as a fixed feature extractor.
        for p in wavLM.parameters():
            p.requires_grad = False
        print('successful to reload wavLM', path)
        return wavLM
| |
|
def get_firstchannel_read(path, fs=16000):
    '''
    Read a wav file, resample it to `fs` if needed, and return only its
    first channel as a 1-D array.
    '''
    samples, native_sr = sf.read(path)
    if native_sr != fs:
        # librosa.resample wants (channels, frames); soundfile yields
        # (frames, channels) for multichannel audio, so flip around the call.
        if len(samples.shape) != 1:
            samples = samples.transpose((1, 0))
        samples = librosa.resample(samples, orig_sr=native_sr, target_sr=fs)
        if len(samples.shape) != 1:
            samples = samples.transpose((1, 0))
    # Keep only channel 0 of multichannel input.
    if len(samples.shape) > 1:
        samples = samples[:, 0]
    return samples
| |
|
def load_obj(obj, device):
    '''
    Recursively move every torch.Tensor inside `obj` (nested dicts/lists
    allowed) to `device`; non-tensor leaves are returned unchanged.
    '''
    if isinstance(obj, dict):
        return {k: load_obj(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [load_obj(item, device) for item in obj]
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    return obj
| |
|
def run(args):
    """Distributed inference entry point.

    Reads the YAML config named by ``args.conf``, builds the task-specific
    data reader, loads the codec, the WavLM feature front-end and the LLaSE
    model, then performs one or more inference passes: each pass extracts
    WavLM features (from the original input on pass 0, from the previous
    pass's output wavs afterwards) and decodes the model's token predictions
    back to waveforms per task (AEC / TSE / PLC / SS / default enhancement).
    """
    # torchrun-provided rank / world-size environment variables.
    LOCAL_RANK = int(os.environ['LOCAL_RANK'])
    WORLD_SIZE = int(os.environ['WORLD_SIZE'])
    WORLD_RANK = int(os.environ['RANK'])
    dist.init_process_group(args.backend, rank=WORLD_RANK, world_size=WORLD_SIZE)
    torch.cuda.set_device(LOCAL_RANK)
    device = torch.device('cuda', LOCAL_RANK)
    print(f"[{os.getpid()}] using device: {device}", torch.cuda.current_device(), "local rank", LOCAL_RANK)

    with open(args.conf, "r") as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)

    # AEC/TSE readers yield mic/ref pairs; the default reader yields a
    # single "mix" signal.
    if conf["task"]=="AEC":
        data_reader = DataReaderAEC(**conf["datareader"])
    elif conf["task"]=="TSE":
        data_reader = DataReaderTSE(**conf["datareader"])
    else:
        data_reader = DataReader(**conf["datareader"])

    # Codec (token <-> wav) and WavLM feature extractor.
    codec = Encodec(device)
    wavlm_feat = WavLM_feat(device)

    # LLaSE model restored from the test checkpoint, wrapped in DDP.
    nnet = model(**conf["nnet_conf"])
    cpt_fname = Path(conf["test"]["checkpoint"])
    cpt = torch.load(cpt_fname, map_location="cpu")

    nnet = nnet.to(device)
    nnet = DistributedDataParallel(nnet, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, find_unused_parameters=True)
    nnet.load_state_dict(cpt["model_state_dict"])
    nnet.eval()

    # Output directory layout: AEC/TSE keep separate mic/ref feature dirs.
    # NOTE(review): the AEC and TSE branches are identical and could be merged.
    if conf["task"]=="AEC":
        if not os.path.exists(conf["save"]["feat_dir"]+"/mic"):
            os.makedirs(conf["save"]["feat_dir"]+"/mic")
        if not os.path.exists(conf["save"]["feat_dir"]+"/ref"):
            os.makedirs(conf["save"]["feat_dir"]+"/ref")
    elif conf["task"]=="TSE":
        if not os.path.exists(conf["save"]["feat_dir"]+"/mic"):
            os.makedirs(conf["save"]["feat_dir"]+"/mic")
        if not os.path.exists(conf["save"]["feat_dir"]+"/ref"):
            os.makedirs(conf["save"]["feat_dir"]+"/ref")
    else:
        if not os.path.exists(conf["save"]["feat_dir"]):
            os.makedirs(conf["save"]["feat_dir"])

    if not os.path.exists(conf["save"]["wav_dir"]):
        os.makedirs(conf["save"]["wav_dir"])

    # Whether pass 0 should (re-)extract WavLM features; later passes
    # always extract them from the previous pass's output wavs.
    if_feat_too = conf["test"]["infer_feat_too"]

    origin_feat_dir = conf["save"]["feat_dir"]
    origin_wav_dir = conf["save"]["wav_dir"]

    # Directories of the previous pass, used as input for pass > 0.
    last_feat_dir = origin_feat_dir
    last_wav_dir = origin_wav_dir

    # Iterative self-refinement: pass N re-processes pass N-1's outputs.
    for inference_time in range(conf["test"]["inference_time"]):
        if inference_time > 0:
            feat_dir = origin_feat_dir + "inference" + str(inference_time)
            wav_dir = origin_wav_dir + "inference" + str(inference_time)
        else:
            feat_dir = origin_feat_dir
            wav_dir = origin_wav_dir

        if not os.path.exists(feat_dir):
            os.makedirs(feat_dir)
        if not os.path.exists(wav_dir):
            os.makedirs(wav_dir)

        with torch.no_grad():
            # ---- Stage 1: WavLM feature extraction (saved as .npy) ----
            if if_feat_too ==True or inference_time>0:
                for egs in tqdm(data_reader):
                    egs = load_obj(egs, device)

                    if conf["task"]=="AEC" or conf["task"]=="TSE":
                        # For pass > 0, re-read the previous pass's output wav
                        # as the new "mic" signal.
                        if inference_time > 0:
                            mic_path = last_wav_dir + '/' + egs["mic_name"] + ".wav"
                            egs["mic"] = torch.from_numpy(get_firstchannel_read(mic_path).astype(np.float32)).unsqueeze(0).to(device)
                        else:
                            egs["mic"]=egs["mic"].contiguous()
                            egs["ref"]=egs["ref"].contiguous()

                        feat_mic = wavlm_feat(egs["mic"])
                        out_mic = feat_mic.detach().squeeze(0).cpu().numpy()

                        if not os.path.exists(os.path.join(feat_dir, "mic")):
                            os.makedirs(os.path.join(feat_dir, "mic"))
                        np.save(os.path.join(feat_dir, "mic", egs["mic_name"]), out_mic)

                        # Reference features never change; only extract once.
                        if inference_time == 0:
                            feat_ref = wavlm_feat(egs["ref"])
                            out_ref = feat_ref.detach().squeeze(0).cpu().numpy()
                            np.save(os.path.join(origin_feat_dir, "ref", egs["ref_name"]), out_ref)

                        torch.cuda.empty_cache()

                    else:
                        # Single-signal tasks (PLC / SS / enhancement).
                        if inference_time > 0:
                            mix_path = last_wav_dir + '/' + egs["name"] + ".wav"
                            egs["mix"] = torch.from_numpy(get_firstchannel_read(mix_path).astype(np.float32)).unsqueeze(0).to(device)
                        else:
                            egs["mix"]=egs["mix"].contiguous()

                        feat = wavlm_feat(egs["mix"])
                        out = feat.detach().squeeze(0).cpu().numpy()
                        np.save(os.path.join(feat_dir, egs["name"]), out)

            # ---- Stage 2: model inference + codec decoding to wav ----
            for egs in tqdm(data_reader):
                egs = load_obj(egs, device)
                sr = 16000

                if conf["task"] == "AEC":
                    feat_path_mic = os.path.join(feat_dir, "mic", egs["mic_name"]) + ".npy"
                    feat_path_ref = os.path.join(origin_feat_dir, "ref", egs["ref_name"]) + ".npy"

                    feat_mic = torch.from_numpy(np.load(feat_path_mic)).unsqueeze(0)
                    feat_ref = torch.from_numpy(np.load(feat_path_ref)).unsqueeze(0)

                    # After the first pass the reference is dropped.
                    if inference_time > 0:
                        est = nnet(feat_mic)
                    else:
                        est = nnet(feat_mic, feat_ref)

                    # Greedy token pick. NOTE(review): `max` shadows the
                    # builtin; only the indices are used. Presumably est[1]
                    # holds per-frame token logits (B, vocab, T) — confirm.
                    max, max_indices_1 = torch.max(est[1], dim=1)
                    recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()

                    target_path = os.path.join(wav_dir, egs["mic_name"] + ".wav")
                    print(target_path)
                    sf.write(target_path , recon_1, sr)

                elif conf["task"] == "TSE" :
                    feat_path_mic = os.path.join(feat_dir, "mic", egs["mic_name"]) + ".npy"
                    feat_path_ref = os.path.join(origin_feat_dir, "ref", egs["ref_name"]) + ".npy"

                    feat_mic = torch.from_numpy(np.load(feat_path_mic)).unsqueeze(0)
                    feat_ref = torch.from_numpy(np.load(feat_path_ref)).unsqueeze(0)

                    # Hard-coded: keep conditioning on the enrollment ref in
                    # every pass. NOTE(review): lift into the YAML config.
                    if_keep_ref = True

                    if inference_time>0 and if_keep_ref== False:
                        est = nnet(feat_mic)
                    else:
                        est = nnet(feat_mic, feat_ref)

                    # NOTE(review): TSE decodes est[0] while the other tasks
                    # decode est[1] — presumably head 0 is the target-speaker
                    # stream; confirm against the model definition.
                    max, max_indices_1 = torch.max(est[0], dim=1)
                    recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()

                    target_path = os.path.join(wav_dir, egs["mic_name"] + ".wav")
                    print(target_path)
                    sf.write(target_path , recon_1, sr)

                elif conf["task"] == "PLC":
                    feat_path = os.path.join(feat_dir, egs["name"]) + ".npy"
                    feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0)

                    est = nnet(feat)
                    max, max_indices_1 = torch.max(est[1], dim=1)
                    recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()

                    target_path = os.path.join(wav_dir, egs["name"] + ".wav")
                    print(target_path)
                    sf.write(target_path , recon_1, sr)

                elif conf["task"] == "SS":
                    feat_path = os.path.join(feat_dir, egs["name"]) + ".npy"
                    feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0)

                    est = nnet(feat)
                    max, max_indices_1 = torch.max(est[1], dim=1)
                    recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()

                    target_path_1 = os.path.join(wav_dir, egs["name"] + ".wav")
                    sf.write(target_path_1 , recon_1, sr)

                    # Second separation stream: condition the original mix on
                    # the first stream's features to extract speaker 2.
                    if inference_time > 0:
                        origin_feat_path = os.path.join(origin_feat_dir, egs["name"]) + ".npy"
                        origin_feat = torch.from_numpy(np.load(origin_feat_path)).unsqueeze(0)

                        est2 = nnet(origin_feat, feat)
                        max, max_indices_2 = torch.max(est2[1], dim=1)
                        recon_2 = codec.token2wav(max_indices_2.unsqueeze(0)).squeeze().detach().cpu().numpy()

                        if not os.path.exists(last_wav_dir + "s2"):
                            os.makedirs(last_wav_dir + "s2")

                        target_path_2 = os.path.join(last_wav_dir + "s2", egs["name"] + ".wav")
                        sf.write(target_path_2 , recon_2, sr)

                else:
                    # Default: plain speech enhancement on the mix.
                    feat_path = os.path.join(feat_dir, egs["name"]) + ".npy"
                    feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0)

                    est = nnet(feat)
                    max, max_indices_1 = torch.max(est[1], dim=1)
                    recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()

                    target_path = os.path.join(wav_dir, egs["name"] + ".wav")
                    print(target_path)
                    sf.write(target_path , recon_1, sr)

        # This pass's outputs feed the next pass.
        last_feat_dir = feat_dir
        last_wav_dir = wav_dir
| |
|
if __name__ == "__main__":
    # CLI: `-conf` (required YAML path) plus an optional distributed
    # backend selector; then hand off to the distributed runner.
    arg_parser = argparse.ArgumentParser(
        description="Command to test separation model in Pytorch",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    arg_parser.add_argument(
        "-conf",
        type=str,
        required=True,
        help="Yaml configuration file for training",
    )
    arg_parser.add_argument(
        "--backend",
        type=str,
        default="nccl",
        choices=["nccl", "gloo"],
    )
    cli_args = arg_parser.parse_args()

    os.environ["NCCL_DEBUG"] = "INFO"
    run(cli_args)