File size: 2,953 Bytes
eca55dc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 | import torch
import torchaudio
def get_spectrogram_shape(waveform_len, hop_length=1250, center=True):
    """Estimate how many spectrogram time steps a waveform produces.

    Uses the closed form ``waveform_len // hop_length + 1``, which is the
    frame count this script expects from a mel spectrogram whose input is
    padded (the ``center=True`` case).

    Args:
        waveform_len: Number of samples in the waveform.
        hop_length: Number of samples between successive frames.
        center: Accepted for signature compatibility; the computation does
            not consult it, so the padded formula is applied either way.

    Returns:
        The estimated number of time steps.
    """
    whole_hops = waveform_len // hop_length
    # One frame sits at the very start, plus one more per complete hop.
    return whole_hops + 1
def calculate_required_length(current_len, hop_length, patch_time_dim):
    """Compute the padded lengths that give an even number of time patches.

    The spectrogram frame count is modelled as
    ``waveform_len // hop_length + 1``. The target frame count is the
    smallest multiple of ``2 * patch_time_dim`` that is not below the
    current frame count, so the time axis splits into an even number of
    patches of ``patch_time_dim`` frames each.

    Args:
        current_len: Current waveform length in samples.
        hop_length: Number of samples between successive frames.
        patch_time_dim: Number of frames covered by one patch along time.

    Returns:
        A ``(target_spec_len, target_wave_len)`` tuple. ``target_wave_len``
        satisfies ``target_wave_len // hop_length + 1 == target_spec_len``
        and is never smaller than ``current_len``.
    """
    spec_len = current_len // hop_length + 1
    block_size = 2 * patch_time_dim

    # Round the frame count up to the next multiple of block_size
    # (ceiling division written with integer arithmetic only).
    target_spec_len = -(-spec_len // block_size) * block_size

    # Every waveform length in
    #   [(target_spec_len - 1) * hop_length, target_spec_len * hop_length - 1]
    # maps to exactly target_spec_len frames. Returning
    # target_spec_len * hop_length, as before, lands one past that range
    # and yields target_spec_len + 1 frames.
    min_wave_len = (target_spec_len - 1) * hop_length

    # If the input already produces target_spec_len frames it lies inside
    # the valid range, possibly above its lower bound; keep it so the
    # result can always be reached by padding, never by trimming.
    target_wave_len = max(current_len, min_wave_len)
    return target_spec_len, target_wave_len
def test_shapes():
    """Print measured versus predicted spectrogram lengths.

    For a handful of waveform lengths, runs a torchaudio ``MelSpectrogram``
    on random audio, asks ``calculate_required_length`` for the padded
    target, runs the transform again on random audio of the suggested
    length, and prints both frame counts. A ``MISMATCH!`` line is printed
    whenever the measured frame count differs from the predicted one.
    """
    hop_length = 1250
    patch_time_dim = 16

    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=32000,
        n_fft=4096,
        win_length=4096,
        hop_length=hop_length,
        n_mels=128,
        center=True,
    )

    print(f"Testing with hop_length={hop_length}, patch_time_dim={patch_time_dim}")

    for n_samples in (32000, 48000, 320000, 12345):
        # Frame count of the unpadded signal (last axis is time).
        measured_frames = mel(torch.randn(1, n_samples)).shape[-1]
        print(f"Wave: {n_samples}, Spec: {measured_frames}")

        # Predicted padded sizes.
        wanted_frames, padded_samples = calculate_required_length(
            n_samples, hop_length, patch_time_dim
        )

        # Frame count actually produced by a signal of the suggested length.
        padded_frames = mel(torch.randn(1, padded_samples)).shape[-1]
        patch_count_is_even = (padded_frames // patch_time_dim) % 2 == 0

        print(f" Target Spec: {wanted_frames}, Target Wave: {padded_samples}")
        print(f" Actual Spec: {padded_frames}")
        print(f" Even patches? {padded_frames / patch_time_dim} (Time patches)")
        print(f" Even time patches condition: {patch_count_is_even}")

        if padded_frames != wanted_frames:
            print(" MISMATCH!")
# Run the shape check only when this file is executed as a script.
if __name__ == "__main__":
    test_shapes()
|