| |
| |
|
|
| import glob |
| import os |
| import matplotlib |
| import torch |
| from torch.nn.utils import weight_norm |
| matplotlib.use("Agg") |
| import matplotlib.pylab as plt |
| from scipy.io.wavfile import write |
|
|
# Full-scale magnitude of signed 16-bit PCM (2 ** 15); save_audio multiplies
# floating-point samples by this factor before casting them to int16.
MAX_WAV_VALUE = float(2 ** 15)
|
|
|
|
def plot_spectrogram(spectrogram):
    """Render ``spectrogram`` as a 10x2-inch heat map with a colour bar.

    The figure's canvas is drawn once and the figure is then closed in
    pyplot, so pyplot's figure manager stops tracking it; the ``Figure``
    object itself is returned to the caller.

    Args:
        spectrogram: Array-like accepted by ``Axes.imshow`` (presumably a
            2-D frequency x time array — confirm with callers). Row 0 is
            drawn at the bottom because ``origin="lower"``.

    Returns:
        The ``matplotlib.figure.Figure`` holding the plot.
    """
    figure, axes = plt.subplots(figsize=(10, 2))
    image = axes.imshow(
        spectrogram,
        aspect="auto",
        origin="lower",
        interpolation='none',
    )
    plt.colorbar(image, ax=axes)

    # Render while the figure is still registered with pyplot, then release it.
    figure.canvas.draw()
    plt.close(figure)
    return figure
|
|
|
|
def plot_spectrogram_clipped(spectrogram, clip_max=2., clip_min=1e-6):
    """Render ``spectrogram`` as a heat map with a fixed colour range.

    Same layout as ``plot_spectrogram`` (10x2 inches, ``origin="lower"``,
    no interpolation). The difference is that the colour scale is pinned to
    ``[clip_min, clip_max]``: values outside that range saturate at the
    ends of the colour map.

    Args:
        spectrogram: Array-like accepted by ``Axes.imshow`` (presumably a
            2-D frequency x time array — confirm with callers).
        clip_max: Upper colour limit, passed to ``imshow`` as ``vmax``.
        clip_min: Lower colour limit, passed to ``imshow`` as ``vmin``.
            Defaults to the previously hard-coded ``1e-6``.

    Returns:
        The ``matplotlib.figure.Figure`` holding the plot. It has already been
        drawn and closed in pyplot.
    """
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none', vmin=clip_min, vmax=clip_max)
    fig.colorbar(im, ax=ax)

    fig.canvas.draw()
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
    return fig
|
|
|
|
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise the ``weight`` of convolution modules from N(mean, std).

    The function takes a single module and returns nothing, so it can be
    passed to ``torch.nn.Module.apply``. Only modules whose class name
    contains ``"Conv"`` (e.g. ``Conv1d``, ``ConvTranspose1d``) are touched;
    every other module is left unchanged.

    Args:
        m: The module to (possibly) re-initialise.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.
    """
    if "Conv" in m.__class__.__name__:
        # nn.init fills the parameter in place under no_grad, instead of
        # writing through the discouraged ``.data`` escape hatch.
        torch.nn.init.normal_(m.weight, mean=mean, std=std)
|
|
|
|
def apply_weight_norm(m):
    """Wrap convolution modules with ``torch.nn.utils.weight_norm``.

    Any module whose class name contains ``"Conv"`` is passed to
    ``weight_norm`` (which modifies it in place); all other modules are
    ignored. Returns ``None``, so it can be used with ``nn.Module.apply``.
    """
    if "Conv" not in m.__class__.__name__:
        return
    weight_norm(m)
|
|
|
|
def get_padding(kernel_size, dilation=1):
    """Return ``dilation * (kernel_size - 1) // 2`` as an int.

    This is the symmetric per-side padding that keeps the sequence length
    unchanged for a stride-1 convolution when ``dilation * (kernel_size - 1)``
    is even (e.g. any odd ``kernel_size``).

    Assumes ``kernel_size >= 1`` and ``dilation >= 1``. Integer floor
    division is exact for arbitrarily large ints, unlike the previous
    float ``/`` followed by ``int()``.
    """
    # int() keeps the return type int even if float arguments are passed.
    return int(dilation * (kernel_size - 1) // 2)
|
|
|
|
def load_checkpoint(filepath, device):
    """Load an object saved with ``torch.save``, remapping tensors to ``device``.

    Progress messages are printed to stdout before and after loading.

    Args:
        filepath: Path to the checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The object returned by ``torch.load`` (named ``checkpoint_dict``
        here; presumably a dict — confirm with callers).

    Raises:
        FileNotFoundError: If ``filepath`` is not an existing regular file.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if not os.path.isfile(filepath):
        raise FileNotFoundError("Checkpoint file not found: '{}'".format(filepath))
    print("Loading '{}'".format(filepath))
    # NOTE(review): torch.load unpickles the file (fully so on older torch
    # versions without weights_only=True); only load trusted checkpoints.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
|
|
|
|
def save_checkpoint(filepath, obj):
    """Serialise ``obj`` to ``filepath`` with ``torch.save``.

    Prints a line before saving and ``Complete.`` once the write returns.
    Returns ``None``.
    """
    announcement = "Saving checkpoint to {}".format(filepath)
    print(announcement)
    torch.save(obj, f=filepath)
    print("Complete.")
|
|
|
|
def scan_checkpoint(cp_dir, prefix):
    """Return the path of the latest checkpoint in ``cp_dir``, or ``None``.

    A checkpoint is any entry named ``prefix`` followed by exactly eight more
    characters (presumably a zero-padded step number — confirm with the
    saver). ``cp_dir`` and ``prefix`` are matched literally, even if they
    contain glob metacharacters such as ``[``, ``*`` or ``?``.

    The "latest" entry is the lexicographically greatest path. This agrees
    with numeric order only when the eight-character suffix has a fixed width.

    Args:
        cp_dir: Directory to search (str or path-like).
        prefix: Literal file-name prefix, e.g. ``"g_"``.

    Returns:
        The matching path, or ``None`` when nothing matches.
    """
    # Escape the literal parts so a directory like "run[1]" is not parsed as
    # a character class; only the trailing '????????' is a wildcard.
    literal_dir = glob.escape(os.fspath(cp_dir))
    pattern = os.path.join(literal_dir, glob.escape(prefix) + '????????')
    cp_list = glob.glob(pattern)
    if not cp_list:
        return None
    return max(cp_list)
|
|
def save_audio(audio, path, sr):
    """Write a floating-point waveform tensor to ``path`` as 16-bit PCM WAV.

    Args:
        audio: Tensor of samples, presumably normalised to [-1, 1]. It is
            scaled by ``MAX_WAV_VALUE``, clamped to the int16 range and cast
            to ``int16``. It may live on any device and may require grad.
        path: Destination passed to ``scipy.io.wavfile.write``.
        sr: Sample rate passed to ``scipy.io.wavfile.write``.
    """
    # NOTE(review): scipy treats a 2-D array as (samples, channels); callers
    # presumably pass a 1-D tensor — confirm.
    # detach(): .numpy() raises on tensors that require grad.
    scaled = audio.detach() * MAX_WAV_VALUE
    # Clamp before casting. A sample of exactly 1.0 scales to 32768, which is
    # out of int16 range; numpy's out-of-range float->int cast is undefined
    # and typically wraps to -32768, which would produce an audible click.
    scaled = scaled.clamp(-MAX_WAV_VALUE, MAX_WAV_VALUE - 1)
    pcm = scaled.cpu().numpy().astype('int16')
    write(path, sr, pcm)