File size: 4,703 Bytes
742e266
 
 
 
 
 
 
 
 
 
 
 
 
0020ddc
742e266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Created: 2026-02-18
# Purpose: Original/residual mel-spectrogram visualization (matplotlib)
# Dependencies: matplotlib, numpy, torch

"""Mel-spectrogram comparison visualization of original audio and analysis results."""

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from config import SR, N_FFT, HOP_LENGTH

N_MELS = 128  # number of mel bands (rows of each spectrogram image)


def mel_filterbank(sr: int, n_fft: int, n_mels: int) -> np.ndarray:
    """Return an HTK-style triangular mel filterbank.

    Band edges are ``n_mels + 2`` points spaced evenly on the mel scale
    (``mel = 2595 * log10(1 + f / 700)``) between 0 Hz and ``sr / 2``. Band
    ``i`` rises linearly from edge ``i`` to a peak of 1 at edge ``i + 1`` and
    falls back to 0 at edge ``i + 2``. Weights are not area-normalized.

    Args:
        sr: Sample rate in Hz.
        n_fft: FFT size; the filterbank covers ``n_fft // 2 + 1`` bins spaced
            evenly from 0 Hz to ``sr / 2``.
        n_mels: Number of mel bands.

    Returns:
        float32 array of shape ``(n_fft // 2 + 1, n_mels)``; column ``i`` holds
        the weights of band ``i`` over the FFT bins.
    """
    def hz_to_mel(f):
        return 2595.0 * np.log10(1.0 + f / 700.0)

    def mel_to_hz(m):
        return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

    n_freqs = n_fft // 2 + 1
    mel_pts = np.linspace(hz_to_mel(0), hz_to_mel(sr / 2), n_mels + 2)
    hz_pts = mel_to_hz(mel_pts)
    # Column vector (one row per FFT bin) so it broadcasts against the
    # per-band edge vectors below into a (n_freqs, n_mels) grid.
    freqs = np.linspace(0, sr / 2, n_freqs)[:, None]

    lo, mid, hi = hz_pts[:-2], hz_pts[1:-1], hz_pts[2:]
    # The triangle is the smaller of the rising and falling ramps, clipped at
    # 0 outside [lo, hi]. This assumes distinct edge points (true for any
    # practical n_mels), so the denominators are non-zero.
    rising = (freqs - lo) / (mid - lo)
    falling = (hi - freqs) / (hi - mid)
    return np.maximum(0.0, np.minimum(rising, falling)).astype(np.float32)


def _compute_mel_spectrogram(audio_1d: np.ndarray) -> np.ndarray:
    """Convert 1D mono audio to a mel spectrogram in dB.

    Computes a Hann-windowed STFT (``nperseg=N_FFT``, hop ``HOP_LENGTH``) at
    ``SR``, projects the power spectrum ``|STFT|^2`` onto ``N_MELS`` mel bands,
    converts to dB, and clips everything to at most 80 dB below the peak.

    Args:
        audio_1d: 1D array of mono samples. Inputs shorter than ``N_FFT``
            (including empty input) are zero-padded to ``N_FFT`` samples.

    Returns:
        Array of shape ``(N_MELS, n_frames)`` in dB.
    """
    from scipy import signal as sig

    audio = np.asarray(audio_1d)
    if audio.shape[0] < N_FFT:
        # For short input, scipy shrinks nperseg to the input length. That
        # either raises (noverlap >= nperseg) or yields fewer frequency bins
        # than the filterbank expects; empty input returns empty arrays.
        # Padding keeps the bin count at N_FFT // 2 + 1.
        audio = np.pad(audio, (0, N_FFT - audio.shape[0]))

    _, _, zxx = sig.stft(audio, fs=SR, window='hann',
                         nperseg=N_FFT, noverlap=N_FFT - HOP_LENGTH)
    power = np.abs(zxx) ** 2

    mel = mel_filterbank(SR, N_FFT, N_MELS).T @ power
    mel_db = 10.0 * np.log10(np.maximum(mel, 1e-10))
    # Limit the dynamic range to 80 dB below the loudest cell.
    return np.maximum(mel_db, np.max(mel_db) - 80.0)


# Dark theme shared by every panel.
_FIG_FACECOLOR = '#1a1a2e'
_AX_FACECOLOR = '#16213e'
# Highest frequency shown on the y-axis (further capped at Nyquist) and the
# candidate y-tick positions, both in kHz.
_MAX_DISPLAY_KHZ = 16.0
_FREQ_TICKS_KHZ = (0.0, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0)


def _hz_to_mel(hz):
    """Convert frequency in Hz to HTK mel: 2595 * log10(1 + f / 700)."""
    return 2595.0 * np.log10(1.0 + np.asarray(hz, dtype=float) / 700.0)


def _draw_mel_panel(ax, mel_db: np.ndarray, n_samples: int, title: str) -> None:
    """Draw one dark-themed mel-spectrogram panel with a dB colorbar on *ax*.

    The rows of ``mel_db`` are mel bands spaced evenly on the HTK mel scale
    from 0 Hz to Nyquist (as built by ``_compute_mel_spectrogram``). The image
    is therefore placed on a mel-valued y-axis, and the ticks are labelled with
    the kHz values they correspond to. Rows are spread uniformly over
    0..mel(Nyquist), which puts each band within half a band spacing of its
    true centre.

    Args:
        ax: matplotlib Axes to draw on.
        mel_db: 2D array (mel bands x frames) in dB.
        n_samples: Number of audio samples the spectrogram was computed from;
            sets the time-axis extent.
        title: Panel title.
    """
    # Use the real duration directly: taking the last element of a
    # linspace collapses to 0 when there is only a single frame.
    duration = n_samples / SR
    mel_top = float(_hz_to_mel(SR / 2))
    im = ax.imshow(mel_db, aspect='auto', origin='lower',
                   extent=[0, duration, 0, mel_top],
                   cmap='magma', interpolation='bilinear')

    # Label the mel axis in kHz and never extend past Nyquist.
    top_khz = min(_MAX_DISPLAY_KHZ, SR / 2000)
    ticks_khz = [k for k in _FREQ_TICKS_KHZ if k <= top_khz]
    ax.set_yticks(_hz_to_mel(np.array(ticks_khz) * 1000.0))
    ax.set_yticklabels([f'{k:g}' for k in ticks_khz])
    ax.set_ylim(0, float(_hz_to_mel(top_khz * 1000.0)))

    ax.set_title(title, fontsize=12, fontweight='bold', color='white')
    ax.set_xlabel('Time (s)', color='white')
    ax.set_ylabel('Frequency (kHz)', color='white')
    ax.set_facecolor(_AX_FACECOLOR)
    ax.tick_params(colors='white')

    # The colorbar also sits on the dark figure background, so its label
    # and ticks need the same white styling as the main axes.
    cbar = ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('dB', color='white')
    cbar.ax.tick_params(colors='white')


def plot_spectrograms(original_mono: np.ndarray,
                      residual_mono: np.ndarray = None,
                      max_seconds: float = 30.0) -> plt.Figure:
    """Return a mel-spectrogram figure: one panel, or original vs. residual.

    Args:
        original_mono: 1D numpy array (mono original).
        residual_mono: 1D numpy array (Demucs residual), optional. When given,
            a second 'Demucs Residual' panel is drawn beside the 'Original'.
        max_seconds: Only the first ``max_seconds`` of each signal are
            analysed and shown (default 30 s).

    Returns:
        matplotlib Figure. It is not closed here; the caller owns it.
    """
    max_samples = int(max_seconds * SR)
    if residual_mono is None:
        panels = [(original_mono[:max_samples], 'Mel Spectrogram')]
    else:
        panels = [(original_mono[:max_samples], 'Original'),
                  (residual_mono[:max_samples], 'Demucs Residual')]

    # squeeze=False keeps `axes` 2D, so one loop handles both layouts.
    fig, axes = plt.subplots(1, len(panels), figsize=(14, 4),
                             constrained_layout=True, squeeze=False)
    fig.patch.set_facecolor(_FIG_FACECOLOR)
    for ax, (audio, title) in zip(axes[0], panels):
        _draw_mel_panel(ax, _compute_mel_spectrogram(audio), len(audio), title)
    return fig