JackIsNotInTheBox committed on
Commit
5a63603
·
verified ·
1 Parent(s): e0d7236

Mirror utils.py from nvidia/bigvgan_v2_44khz_128band_512x@95a9d1dc

Browse files
encoders/nvidia/bigvgan_v2_44khz_128band_512x/utils.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import glob
5
+ import os
6
+ import matplotlib
7
+ import torch
8
+ from torch.nn.utils import weight_norm
9
+
10
+ matplotlib.use("Agg")
11
+ import matplotlib.pylab as plt
12
+ from meldataset import MAX_WAV_VALUE
13
+ from scipy.io.wavfile import write
14
+
15
+
16
def plot_spectrogram(spectrogram):
    """Draw ``spectrogram`` as an image with a colorbar and return the figure.

    Args:
        spectrogram: 2-D array-like handed directly to ``Axes.imshow``
            (presumably a mel spectrogram; the layout is not checked here).

    Returns:
        The matplotlib figure, already rendered to its canvas and closed
        in pyplot.
    """
    figure, axes = plt.subplots(figsize=(10, 2))
    image = axes.imshow(
        spectrogram,
        aspect="auto",
        origin="lower",
        interpolation="none",
    )
    plt.colorbar(image, ax=axes)

    # Render once, then close the current pyplot figure so repeated calls do
    # not keep accumulating open figures.
    figure.canvas.draw()
    plt.close()

    return figure
25
+
26
+
27
def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
    """Draw ``spectrogram`` with a fixed color range and return the figure.

    The color scale is pinned to ``[1e-6, clip_max]`` so plots produced by
    different calls share the same color mapping.

    Args:
        spectrogram: 2-D array-like handed directly to ``Axes.imshow``.
        clip_max: Upper bound of the color scale (``vmax``).

    Returns:
        The matplotlib figure, already rendered to its canvas and closed
        in pyplot.
    """
    figure, axes = plt.subplots(figsize=(10, 2))
    image = axes.imshow(
        spectrogram,
        aspect="auto",
        origin="lower",
        interpolation="none",
        vmin=1e-6,
        vmax=clip_max,
    )
    plt.colorbar(image, ax=axes)

    # Render once, then close the current pyplot figure so repeated calls do
    # not keep accumulating open figures.
    figure.canvas.draw()
    plt.close()

    return figure
43
+
44
+
45
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialize the weights of convolution-like modules in place.

    A module is treated as a convolution when its class name contains the
    substring ``"Conv"``; any other module is left untouched.

    Args:
        m: Module to inspect.
        mean: Mean passed to ``normal_``.
        std: Standard deviation passed to ``normal_``.
    """
    if "Conv" not in m.__class__.__name__:
        return
    m.weight.data.normal_(mean, std)
49
+
50
+
51
def apply_weight_norm(m):
    """Wrap convolution-like modules with weight normalization.

    A module is treated as a convolution when its class name contains the
    substring ``"Conv"``; any other module is left untouched.

    Args:
        m: Module to inspect.
    """
    if "Conv" not in m.__class__.__name__:
        return
    weight_norm(m)
55
+
56
+
57
def get_padding(kernel_size, dilation=1):
    """Return the padding computed as ``int((kernel_size - 1) * dilation / 2)``.

    Args:
        kernel_size: Convolution kernel size.
        dilation: Convolution dilation factor.

    Returns:
        The padding as an ``int`` (the half-span truncated toward zero).
    """
    dilated_span = kernel_size * dilation - dilation
    half_span = dilated_span / 2
    return int(half_span)
59
+
60
+
61
def load_checkpoint(filepath, device):
    """Load a checkpoint file with ``torch.load`` and return its contents.

    Args:
        filepath: Path to an existing checkpoint file.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialized from ``filepath``.

    Raises:
        FileNotFoundError: If ``filepath`` is not an existing regular file.
    """
    # Raise explicitly instead of using ``assert``: assertions are stripped
    # under ``python -O`` and a bare AssertionError does not name the path.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint file not found: '{filepath}'")
    print(f"Loading '{filepath}'")
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
67
+
68
+
69
def save_checkpoint(filepath, obj):
    """Serialize ``obj`` to ``filepath`` with ``torch.save``.

    Progress messages are printed before and after the write.

    Args:
        filepath: Destination path for the checkpoint.
        obj: Object to serialize.
    """
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")
73
+
74
+
75
def scan_checkpoint(cp_dir, prefix, renamed_file=None):
    """Locate the checkpoint to resume from inside ``cp_dir``.

    Files named ``prefix`` followed by exactly eight characters take
    priority; among those, the lexicographically greatest path is chosen.
    If none exist and ``renamed_file`` is given, that file is used when it
    exists in ``cp_dir``.

    Args:
        cp_dir: Directory to search.
        prefix: Checkpoint filename prefix.
        renamed_file: Optional fallback filename inside ``cp_dir``.

    Returns:
        The path of the selected checkpoint, or ``None`` if nothing was found.
    """
    candidates = glob.glob(os.path.join(cp_dir, prefix + "????????"))

    if candidates:
        # Lexicographic maximum; equivalent to the last element after sorting.
        last_checkpoint_path = max(candidates)
        print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
        return last_checkpoint_path

    # No pattern-based checkpoint: fall back to the explicitly named file.
    if not renamed_file:
        return None

    renamed_path = os.path.join(cp_dir, renamed_file)
    if not os.path.isfile(renamed_path):
        return None

    print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
    return renamed_path
93
+
94
+
95
def save_audio(audio, path, sr):
    """Write a waveform tensor to ``path`` as a 16-bit PCM WAV file.

    Args:
        audio: Torch tensor holding the waveform (the original comment
            describes it as 1-D; presumably normalized to [-1, 1] — confirm
            against callers).
        path: Destination file path.
        sr: Sample rate forwarded to ``scipy.io.wavfile.write``.
    """
    int16_range = torch.iinfo(torch.int16)
    # detach() so tensors that still require grad can be converted to NumPy.
    scaled = audio.detach() * MAX_WAV_VALUE
    # Clamp before the cast: without it, out-of-range samples do not saturate
    # when converted to int16 and can come out as wrapped/garbage values.
    scaled = scaled.clamp(min=int16_range.min, max=int16_range.max)
    samples = scaled.cpu().numpy().astype("int16")
    write(path, sr, samples)