| import os |
| import errno |
| import numpy as np |
|
|
| from copy import deepcopy |
| from miscc.config import cfg |
| from scipy.io.wavfile import write |
| from torch.nn import init |
| import torch |
| import torch.nn as nn |
| import torchvision.utils as vutils |
| from wavefile import WaveWriter, Format |
| import RT60 |
| from multiprocessing import Pool |
|
|
|
|
| |
def KL_loss(mu, logvar):
    """KL divergence between N(mu, exp(logvar)) and the standard normal prior.

    Computes -0.5 * mean(1 + logvar - mu^2 - exp(logvar)), the usual
    VAE regularizer, averaged over every element of the batch.
    """
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * torch.mean(per_element)
|
|
|
|
def compute_discriminator_loss(netD, real_RIRs, fake_RIRs,
                               real_labels, fake_labels,
                               conditions, gpus):
    """Conditional GAN discriminator loss over real / mismatched / fake pairs.

    Returns (errD, errD_real.data, errD_wrong.data, errD_fake.data) where
    errD is the total loss and the rest are detached per-term values.
    """
    bce = nn.BCELoss()
    batch_size = real_RIRs.size(0)
    cond = conditions.detach()
    fake_detached = fake_RIRs.detach()

    # Shared discriminator trunk features for real and generated RIRs.
    feat_real = nn.parallel.data_parallel(netD, real_RIRs, gpus)
    feat_fake = nn.parallel.data_parallel(netD, fake_detached, gpus)

    # Real RIR with its matching condition -> should score "real".
    logits_real = nn.parallel.data_parallel(
        netD.get_cond_logits, (feat_real, cond), gpus)
    errD_real = bce(logits_real, real_labels)

    # Real RIR paired with a shifted (mismatched) condition -> "fake".
    logits_wrong = nn.parallel.data_parallel(
        netD.get_cond_logits, (feat_real[:(batch_size - 1)], cond[1:]), gpus)
    errD_wrong = bce(logits_wrong, fake_labels[1:])

    # Generated RIR with its matching condition -> "fake".
    logits_fake = nn.parallel.data_parallel(
        netD.get_cond_logits, (feat_fake, cond), gpus)
    errD_fake = bce(logits_fake, fake_labels)

    if netD.get_uncond_logits is not None:
        # Optional unconditional head: average its terms into the totals.
        uncond_errD_real = bce(
            nn.parallel.data_parallel(netD.get_uncond_logits, feat_real, gpus),
            real_labels)
        uncond_errD_fake = bce(
            nn.parallel.data_parallel(netD.get_uncond_logits, feat_fake, gpus),
            fake_labels)

        errD = ((errD_real + uncond_errD_real) / 2. +
                (errD_fake + errD_wrong + uncond_errD_fake) / 3.)
        errD_real = (errD_real + uncond_errD_real) / 2.
        errD_fake = (errD_fake + uncond_errD_fake) / 2.
    else:
        errD = errD_real + (errD_fake + errD_wrong) * 0.5
    return errD, errD_real.data, errD_wrong.data, errD_fake.data
| |
|
|
|
|
|
|
def compute_generator_loss(epoch, netD, real_RIRs, fake_RIRs, real_labels, conditions, gpus):
    """Generator loss: conditional adversarial BCE + weighted MSE + weighted T60 error.

    Returns (errD_fake, MSE_error, RT_error) where
      - errD_fake is the total generator loss (on the autograd graph),
      - MSE_error is the L1 distance between real and fake RIRs (the name is
        kept for interface compatibility, but it is computed with nn.L1Loss;
        the true MSE enters the loss term below),
      - RT_error is the mean per-channel T60 (reverberation time) error.

    Fix over the original: the multiprocessing Pool is now cleaned up even
    when a worker raises (it used to leak on any exception).
    """
    criterion = nn.BCELoss()
    loss = nn.L1Loss()
    loss1 = nn.MSELoss()

    cond = conditions.detach()
    fake_features = nn.parallel.data_parallel(netD, fake_RIRs, gpus)

    # Conditional real/fake logits for the generated RIRs.
    inputs = (fake_features, cond)
    fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)

    MSE_error = loss(real_RIRs, fake_RIRs)      # L1 term (reported)
    MSE_error1 = loss1(real_RIRs, fake_RIRs)    # MSE term (in the loss)

    sample_size = real_RIRs.size()[0]
    channel = 12   # number of consecutive RIRs whose T60 is compared per step
    fs = 16000     # sample rate assumed by the T60 estimator

    # Pick a random run of `channel` consecutive batch items.
    # NOTE(review): assumes sample_size > channel * 2 (batch > 24), otherwise
    # np.random.randint raises; also assumes each RIR flattens to exactly
    # 4096 samples — confirm against the data loader.
    rn = np.random.randint(sample_size - (channel * 2))
    real_wave = np.array(real_RIRs[rn:rn + channel].to("cpu").detach())
    real_wave = real_wave.reshape(channel, 4096)
    fake_wave = np.array(fake_RIRs[rn:rn + channel].to("cpu").detach())
    fake_wave = fake_wave.reshape(channel, 4096)

    # Estimate per-channel T60 error in worker processes.
    pool = Pool(processes=12)
    try:
        results = [pool.apply_async(RT60.t60_parallel,
                                    args=(n, real_wave, fake_wave, fs))
                   for n in range(channel)]
        T60_error = 0
        for result in results:
            T60_error = T60_error + result.get()
    except Exception:
        # A worker failed: kill the pool instead of leaking it.
        pool.terminate()
        pool.join()
        raise
    else:
        pool.close()
        pool.join()

    RT_error = T60_error / channel

    # Adversarial term plus heavily weighted reconstruction and T60 terms.
    errD_fake = (criterion(fake_logits, real_labels)
                 + 5 * 4096 * MSE_error1 + 40 * RT_error)
    if netD.get_uncond_logits is not None:
        fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits,
                                                fake_features, gpus)
        uncond_errD_fake = criterion(fake_logits, real_labels)
        errD_fake += uncond_errD_fake
    return errD_fake, MSE_error, RT_error
|
|
|
|
| |
def weights_init(m):
    """DCGAN-style initializer, dispatched on the module's class name.

    Conv*: weights ~ N(0, 0.02). BatchNorm*: weights ~ N(1, 0.02), bias 0.
    Linear: weights ~ N(0, 0.02), bias 0 when present. Intended for use
    with Module.apply(weights_init).
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif 'Linear' in classname:
        m.weight.data.normal_(0.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
|
|
|
|
| |
def save_RIR_results(data_RIR, fake, epoch, RIR_dir):
    """Export the first cfg.VIS_COUNT RIRs as mono 16 kHz wav files.

    When data_RIR is given, writes paired real_sample{i}.wav and
    fake_sample{i}_epoch_{epoch}.wav files; otherwise writes only
    small_fake_sample{i}_epoch_{epoch}.wav files.

    Fix over the original: WaveWriter handles are now closed via context
    managers — they were never closed before, so the wav files could be
    left unflushed/truncated.
    """
    num = cfg.VIS_COUNT
    fake = fake[0:num]
    fs = 16000  # sample rate for every written file

    if data_RIR is not None:
        data_RIR = data_RIR[0:num]
        for i in range(num):
            real_RIR_path = RIR_dir + "/real_sample" + str(i) + ".wav"
            fake_RIR_path = RIR_dir + "/fake_sample" + str(i) + "_epoch_" + str(epoch) + ".wav"

            real_IR = np.array(data_RIR[i].to("cpu").detach())
            fake_IR = np.array(fake[i].to("cpu").detach())

            with WaveWriter(real_RIR_path, channels=1, samplerate=fs) as r:
                r.write(real_IR)
            with WaveWriter(fake_RIR_path, channels=1, samplerate=fs) as f:
                f.write(fake_IR)
    else:
        for i in range(num):
            fake_RIR_path = RIR_dir + "/small_fake_sample" + str(i) + "_epoch_" + str(epoch) + ".wav"
            fake_IR = np.array(fake[i].to("cpu").detach())
            with WaveWriter(fake_RIR_path, channels=1, samplerate=fs) as f:
                f.write(fake_IR)
| |
| |
|
|
| |
| |
| |
|
|
|
|
def save_model(netG, netD, epoch, model_dir):
    """Checkpoint both networks into model_dir.

    The generator gets a per-epoch file (netG_epoch_{epoch}.pth); the
    discriminator is written to a single rolling file (netD_epoch_last.pth)
    that is overwritten on every call.
    """
    generator_path = '%s/netG_epoch_%d.pth' % (model_dir, epoch)
    discriminator_path = '%s/netD_epoch_last.pth' % (model_dir)
    torch.save(netG.state_dict(), generator_path)
    torch.save(netD.state_dict(), discriminator_path)
| |
|
|
|
|
def mkdir_p(path):
    """Create *path* and any missing parents, like `mkdir -p`.

    An existing directory at *path* is not an error; any other OSError
    (including *path* existing as a non-directory) is re-raised.
    """
    try:
        os.makedirs(path)
    except OSError as exc:
        # Only swallow "already exists" when it really is a directory.
        if exc.errno != errno.EEXIST or not os.path.isdir(path):
            raise
|
|