File size: 7,979 Bytes
705a8fd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 | import torch
from torch import Tensor
from typing import Optional
import math
import warnings
class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix. This uses triangular filter banks.

    The filter bank is stored in the registered buffer ``fb``. It therefore moves
    together with the module (e.g. ``mel_scale.to(spec_f.device)``), and it can
    also be moved on its own (e.g. ``mel_scale.fb.to(spec_f.device)``).

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Raises:
        ValueError: If ``f_min`` is not less than or equal to the effective ``f_max``.

    See also:
        :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
        generate the filter banks.
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: int = 201,
                 norm: Optional[str] = None,
                 mel_scale: str = "htk") -> None:
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.norm = norm
        self.mel_scale = mel_scale

        # Explicit validation instead of ``assert``: assertions are stripped under
        # ``python -O``. Written as ``not (a <= b)`` so a NaN bound is still rejected.
        if not f_min <= self.f_max:
            raise ValueError('Require f_min: {} <= f_max: {}'.format(f_min, self.f_max))

        fb = melscale_fbanks(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
            self.mel_scale)
        # Registered as a buffer so it follows ``.to()`` / ``.cuda()`` and is saved in state_dict.
        self.register_buffer('fb', fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        # (..., time, freq) @ (freq, n_mels) -> (..., time, n_mels) -> (..., n_mels, time)
        mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)
        return mel_specgram
def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
r"""Convert Hz to Mels.
Args:
freqs (float): Frequencies in Hz
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Returns:
mels (float): Frequency in Mels
"""
if mel_scale not in ['slaney', 'htk']:
raise ValueError('mel_scale should be one of "htk" or "slaney".')
if mel_scale == "htk":
return 2595.0 * math.log10(1.0 + (freq / 700.0))
# Fill in the linear part
f_min = 0.0
f_sp = 200.0 / 3
mels = (freq - f_min) / f_sp
# Fill in the log-scale part
min_log_hz = 1000.0
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = math.log(6.4) / 27.0
if freq >= min_log_hz:
mels = min_log_mel + math.log(freq / min_log_hz) / logstep
return mels
def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
"""Convert mel bin numbers to frequencies.
Args:
mels (Tensor): Mel frequencies
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Returns:
freqs (Tensor): Mels converted in Hz
"""
if mel_scale not in ['slaney', 'htk']:
raise ValueError('mel_scale should be one of "htk" or "slaney".')
if mel_scale == "htk":
return 700.0 * (10.0**(mels / 2595.0) - 1.0)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
# And now the nonlinear scale
min_log_hz = 1000.0
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = math.log(6.4) / 27.0
log_t = (mels >= min_log_mel)
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
return freqs
def _create_triangular_filterbank(
all_freqs: Tensor,
f_pts: Tensor,
) -> Tensor:
"""Create a triangular filter bank.
Args:
all_freqs (Tensor): STFT freq points of size (`n_freqs`).
f_pts (Tensor): Filter mid points of size (`n_filter`).
Returns:
fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
"""
# Adopted from Librosa
# calculate the difference between each filter mid point and each stft freq point in hertz
f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1)
slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_filter + 2)
# create overlapping triangles
zero = torch.zeros(1)
down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter)
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter)
fb = torch.max(zero, torch.min(down_slopes, up_slopes))
return fb
def melscale_fbanks(
    n_freqs: int,
    f_min: float,
    f_max: float,
    n_mels: int,
    sample_rate: int,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
) -> Tensor:
    r"""Create a frequency bin conversion matrix.

    Note:
        For the sake of numerical compatibility with librosa, not all the coefficients
        in the resulting filter bank have a magnitude of 1.

    .. image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png
       :alt: Visualization of generated filter bank

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``),
        i.e. number of frequencies to highlight/apply by number of filterbanks.
        Each column is one filterbank, so for a tensor ``A`` of size
        (..., ``n_freqs``) the applied result is the matrix product
        ``A @ melscale_fbanks(A.size(-1), ...)``.

    Raises:
        ValueError: If ``norm`` is neither ``None`` nor ``'slaney'``, or if
            ``mel_scale`` is neither ``"htk"`` nor ``"slaney"``.

    Warns:
        UserWarning: If at least one filterbank column is entirely zero.
    """
    if norm not in (None, "slaney"):
        raise ValueError("norm must be one of None or 'slaney'")

    # Evenly spaced STFT bin frequencies from 0 Hz up to ``sample_rate // 2``.
    # NOTE(review): integer division puts the top bin 0.5 Hz below Nyquist when
    # ``sample_rate`` is odd — confirm this is intended.
    stft_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # n_mels + 2 points evenly spaced on the mel scale (n_mels peaks plus the
    # two outer edges), converted back to Hz.
    mel_lo = _hz_to_mel(f_min, mel_scale=mel_scale)
    mel_hi = _hz_to_mel(f_max, mel_scale=mel_scale)
    corner_hz = _mel_to_hz(torch.linspace(mel_lo, mel_hi, n_mels + 2), mel_scale=mel_scale)

    fb = _create_triangular_filterbank(stft_freqs, corner_hz)

    if norm == "slaney":
        # Slaney-style area normalization: scale each triangle by 2 / (base width in Hz)
        # so every channel carries approximately constant energy.
        band_widths = corner_hz[2:n_mels + 2] - corner_hz[:n_mels]
        fb *= (2.0 / band_widths).unsqueeze(0)

    # A column whose maximum is zero is a filter that no STFT bin falls into.
    if (fb.max(dim=0).values == 0.).any():
        warnings.warn(
            "At least one mel filterbank has all zero values. "
            f"The value for `n_mels` ({n_mels}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )
    return fb