| import torch |
| import torch.nn as nn |
| import torchaudio.transforms as T |
| import numpy as np |
|
|
|
|
class MultiViewSpectrogram(nn.Module):
    """Log-mel spectrograms of one signal at several STFT window lengths.

    One ``torchaudio`` ``MelSpectrogram`` is built per entry of
    ``win_lengths``. All of them share the hop length and the mel layout, so
    their outputs have the same shape and can be stacked into a "views" axis.

    Args:
        sample_rate: Sampling rate of the input waveform in Hz.
        n_mels: Number of mel bands per view.
        hop_length: Hop between successive frames, in samples.
        win_lengths: Analysis window length, in samples, for each view. The
            default (368, 736, 1488) is roughly 23, 46 and 93 ms at 16 kHz.
    """

    def __init__(self, sample_rate=16000, n_mels=80, hop_length=160,
                 win_lengths=(368, 736, 1488)):
        super().__init__()

        self.win_lengths = [int(length) for length in win_lengths]

        # The STFT carries no content above sample_rate / 2. Mel filters
        # placed beyond that limit cover no frequency bins and would emit a
        # constant log(1e-9), so the upper edge is clamped to Nyquist.
        f_max = min(16000.0, sample_rate / 2)

        self.transforms = nn.ModuleList()
        for win_len in self.win_lengths:
            # Smallest power of two that can hold the window, computed with
            # integer arithmetic to avoid float rounding in log2.
            n_fft = 1 << (win_len - 1).bit_length()
            self.transforms.append(
                T.MelSpectrogram(
                    sample_rate=sample_rate,
                    n_fft=n_fft,
                    win_length=win_len,
                    hop_length=hop_length,
                    f_min=27.5,
                    f_max=f_max,
                    n_mels=n_mels,
                    power=1.0,  # magnitude rather than power spectrogram
                    center=True,
                )
            )

    def forward(self, waveform):
        """Return the stacked log-magnitude mel spectrograms of ``waveform``.

        Args:
            waveform: Tensor whose last axis is time.

        Returns:
            The per-view spectrograms stacked on a new axis. For input with
            two or more dimensions the new axis is ``dim=1``; for a 1-D
            waveform it is ``dim=0``, so the view axis always precedes the
            spectrogram axes of an unbatched signal.
        """
        # 1e-9 keeps the logarithm finite on silent frames.
        specs = [torch.log(view(waveform) + 1e-9) for view in self.transforms]
        # A 1-D waveform has no batch axis; stacking at dim=1 would wedge the
        # view axis between the mel and time axes.
        stack_dim = 0 if waveform.dim() == 1 else 1
        return torch.stack(specs, dim=stack_dim)
|
|
|
|
def extract_context(spec, center_frame, context=7, pad_value=0.0):
    """Cut a fixed-width window of frames centred on ``center_frame``.

    The window spans ``center_frame - context`` to ``center_frame + context``
    inclusive along the last axis. Where it reaches past either end of
    ``spec``, the missing frames are filled with ``pad_value``, so the result
    always has ``2 * context + 1`` frames.

    Args:
        spec: Tensor whose last axis is time. Any number of leading axes is
            accepted; they are passed through unchanged.
        center_frame: Index of the frame at the middle of the window.
        context: Number of frames taken on each side of ``center_frame``.
        pad_value: Fill value for frames outside ``spec``. Defaults to 0.0.

    Returns:
        Tensor with the same leading axes as ``spec`` and a last axis of
        length ``2 * context + 1``.
    """
    total_time = spec.shape[-1]
    start = center_frame - context
    end = center_frame + context + 1

    pad_left = max(0, -start)
    pad_right = max(0, end - total_time)

    if pad_left or pad_right:
        # A two-element pad spec applies to the last (time) axis only.
        spec = torch.nn.functional.pad(
            spec, (pad_left, pad_right), mode="constant", value=pad_value
        )
        # Left padding shifts every original frame index to the right.
        start += pad_left
        end += pad_left

    return spec[..., start:end]
|
|