File size: 7,979 Bytes
705a8fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import torch
from torch import Tensor
from typing import Optional
import math

import warnings

class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT using a conversion matrix.

    The conversion matrix is a bank of triangular filters produced by
    ``melscale_fbanks`` at construction time and registered on the module as
    the buffer ``fb``.

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Raises:
        AssertionError: If ``f_min`` is not less than or equal to the resolved ``f_max``.

    See also:
        :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
        generate the filter banks.
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: int = 201,
                 norm: Optional[str] = None,
                 mel_scale: str = "htk") -> None:
        super().__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        # Fall back to the Nyquist frequency when no upper bound is given.
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.norm = norm
        self.mel_scale = mel_scale

        # Raised explicitly (instead of via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept as
        # AssertionError so existing callers see the same error class.
        # ``not (a <= b)`` also rejects NaN bounds, exactly like the old assert.
        if not f_min <= self.f_max:
            raise AssertionError(
                'Require f_min: {} <= f_max: {}'.format(f_min, self.f_max))

        fb = melscale_fbanks(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
            self.mel_scale)
        self.register_buffer('fb', fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Apply the mel filter bank to a spectrogram.

        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        # Move freq to the last axis so it contracts against the filter bank:
        # (..., time, freq) @ (freq, n_mels) -> (..., time, n_mels),
        # then swap back to (..., n_mels, time).
        time_major = specgram.transpose(-1, -2)
        mel_specgram = torch.matmul(time_major, self.fb).transpose(-1, -2)

        return mel_specgram

def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
    r"""Convert a single frequency from Hz to mels.

    Args:
        freq (float): Frequency in Hz
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        mels (float): Frequency in Mels

    Raises:
        ValueError: If ``mel_scale`` is neither ``"htk"`` nor ``"slaney"``.
    """
    if mel_scale not in ('slaney', 'htk'):
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 2595.0 * math.log10(1.0 + (freq / 700.0))

    # Slaney scale: linear below the 1 kHz knee, logarithmic at and above it.
    linear_origin_hz = 0.0
    hz_per_mel = 200.0 / 3
    knee_hz = 1000.0

    if freq >= knee_hz:
        # Logarithmic region, anchored at the mel value of the knee.
        knee_mel = (knee_hz - linear_origin_hz) / hz_per_mel
        log_step = math.log(6.4) / 27.0
        return knee_mel + math.log(freq / knee_hz) / log_step

    # Linear region.
    return (freq - linear_origin_hz) / hz_per_mel

def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
    """Convert mel values to frequencies in Hz.

    Args:
        mels (Tensor): Mel frequencies
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        freqs (Tensor): Mels converted in Hz

    Raises:
        ValueError: If ``mel_scale`` is neither ``"htk"`` nor ``"slaney"``.
    """
    if mel_scale not in ('slaney', 'htk'):
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 700.0 * (10.0**(mels / 2595.0) - 1.0)

    # Slaney scale constants: linear below the 1 kHz knee, exponential above.
    linear_origin_hz = 0.0
    hz_per_mel = 200.0 / 3
    knee_hz = 1000.0
    knee_mel = (knee_hz - linear_origin_hz) / hz_per_mel
    log_step = math.log(6.4) / 27.0

    # Start from the linear mapping for every element (this is a new tensor,
    # so the caller's ``mels`` is left untouched) ...
    hz = linear_origin_hz + hz_per_mel * mels

    # ... then overwrite the entries at or above the knee with the
    # exponential mapping.
    above_knee = (mels >= knee_mel)
    hz[above_knee] = knee_hz * torch.exp(log_step * (mels[above_knee] - knee_mel))

    return hz

def _create_triangular_filterbank(
        all_freqs: Tensor,
        f_pts: Tensor,
) -> Tensor:
    """Create a triangular filter bank.

    Filter ``i`` rises from ``f_pts[i]`` to a peak at ``f_pts[i + 1]`` and
    falls back to zero at ``f_pts[i + 2]``.

    Args:
        all_freqs (Tensor): STFT freq points of size (`n_freqs`).
        f_pts (Tensor): Filter edge and mid points of size (`n_filter` + 2).

    Returns:
        fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
    """
    # Adopted from Librosa
    # Distance between consecutive filter points.
    band_widths = f_pts[1:] - f_pts[:-1]  # (n_filter + 1)
    # Signed distance from every STFT frequency to every filter point.
    offsets = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_filter + 2)

    # Rising edge: 0 at the left point, 1 at the centre point.
    rising = (-1.0 * offsets[:, :-2]) / band_widths[:-1]  # (n_freqs, n_filter)
    # Falling edge: 1 at the centre point, 0 at the right point.
    falling = offsets[:, 2:] / band_widths[1:]  # (n_freqs, n_filter)

    # The triangle is the lower of the two edges, floored at zero. Clamping
    # against a scalar (rather than a freshly allocated CPU ``torch.zeros(1)``)
    # keeps the result on the inputs' device and avoids a device-mismatch
    # error when the inputs are not on the CPU.
    fb = torch.clamp(torch.min(rising, falling), min=0.0)

    return fb

def melscale_fbanks(
        n_freqs: int,
        f_min: float,
        f_max: float,
        n_mels: int,
        sample_rate: int,
        norm: Optional[str] = None,
        mel_scale: str = "htk",
) -> Tensor:
    r"""Create a frequency bin conversion matrix.

    Note:
        For the sake of the numerical compatibility with librosa, not all the coefficients
        in the resulting filter bank have a magnitude of 1.

        .. image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png
           :alt: Visualization of generated filter bank

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``),
        i.e. the number of frequencies by the number of filterbanks.
        Each column is a filterbank, so for a matrix A of size (..., ``n_freqs``)
        the applied result is ``torch.matmul(A, melscale_fbanks(A.size(-1), ...))``.

    Raises:
        ValueError: If ``norm`` is neither ``None`` nor ``"slaney"``.
    """
    use_slaney_norm = norm == "slaney"
    if not (norm is None or use_slaney_norm):
        raise ValueError("norm must be one of None or 'slaney'")

    # Linearly spaced STFT bin frequencies from 0 Hz up to sample_rate // 2.
    stft_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # Filter points are evenly spaced on the mel axis between the two bounds;
    # n_mels filters need n_mels + 2 points (left edge, centres, right edge).
    mel_low = _hz_to_mel(f_min, mel_scale=mel_scale)
    mel_high = _hz_to_mel(f_max, mel_scale=mel_scale)
    mel_points = torch.linspace(mel_low, mel_high, n_mels + 2)
    hz_points = _mel_to_hz(mel_points, mel_scale=mel_scale)

    fb = _create_triangular_filterbank(stft_freqs, hz_points)

    if use_slaney_norm:
        # Slaney-style mel is scaled to be approx constant energy per channel:
        # each filter is multiplied by 2 / (right edge - left edge).
        band_span = hz_points[2:n_mels + 2] - hz_points[:n_mels]
        fb = fb * (2.0 / band_span).unsqueeze(0)

    # A column whose maximum is exactly zero is a filter that never fires.
    column_peaks = fb.max(dim=0).values
    if (column_peaks == 0.).any():
        warnings.warn(
            "At least one mel filterbank has all zero values. "
            f"The value for `n_mels` ({n_mels}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )

    return fb