# Created: 2026-02-18
# Purpose: Original/residual mel-spectrogram visualization (matplotlib)
# Dependencies: matplotlib, numpy, scipy
"""Mel-spectrogram comparison visualization of original audio and analysis results."""

import functools
from typing import Optional

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure

from config import SR, N_FFT, HOP_LENGTH

N_MELS = 128

# Maximum length of audio rendered per panel, in seconds.
_MAX_SECONDS = 30
# Upper edge of the displayed frequency range (Hz); clipped to Nyquist at draw time.
_MAX_DISPLAY_HZ = 16000.0
# Frequencies (Hz) that receive a labelled tick on the mel-scaled y axis.
_FREQ_TICKS_HZ = (0.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0, 16000.0)

# Dark theme colours.
_FIG_BG = '#1a1a2e'
_AX_BG = '#16213e'
_FG = 'white'


def _hz_to_mel(f):
    """Convert frequency in Hz (scalar or array) to mel: 2595 * log10(1 + f / 700)."""
    return 2595.0 * np.log10(1.0 + np.asarray(f) / 700.0)


def _mel_to_hz(m):
    """Inverse of `_hz_to_mel`: convert mel (scalar or array) back to Hz."""
    return 700.0 * (10.0 ** (np.asarray(m) / 2595.0) - 1.0)


@functools.lru_cache(maxsize=8)
def _mel_filterbank(sr, n_fft, n_mels):
    """Return a read-only (n_mels, n_fft // 2 + 1) triangular mel filterbank.

    Filter edges are equally spaced on the mel scale from 0 Hz to sr / 2.
    Filter i rises linearly from 0 at its lower edge to 1 at its centre and
    falls back to 0 at its upper edge. The result depends only on the
    configuration, so it is cached instead of rebuilt for every call.
    """
    freqs = np.linspace(0.0, sr / 2, n_fft // 2 + 1)
    hz_pts = _mel_to_hz(np.linspace(_hz_to_mel(0.0), _hz_to_mel(sr / 2), n_mels + 2))
    lo = hz_pts[:-2, None]
    mid = hz_pts[1:-1, None]
    hi = hz_pts[2:, None]
    rising = (freqs - lo) / (mid - lo)
    falling = (hi - freqs) / (hi - mid)
    # min() selects the rising slope below the centre and the falling slope
    # above it; clamping at 0 zeroes everything outside [lo, hi].
    fb = np.maximum(0.0, np.minimum(rising, falling)).astype(np.float32)
    fb.setflags(write=False)  # shared through the cache; must not be mutated
    return fb


def _compute_mel_spectrogram(audio_1d: np.ndarray) -> np.ndarray:
    """1D audio -> mel spectrogram (dB scale).

    Args:
        audio_1d: 1D audio samples at sample rate SR.

    Returns:
        Array of shape (N_MELS, n_frames) with mel-band power in dB, floored
        at 80 dB below the spectrogram's maximum.

    Inputs shorter than N_FFT samples (including empty input) are zero-padded
    to N_FFT. Otherwise scipy would shrink the segment length below the
    requested overlap and raise.
    """
    from scipy import signal as sig  # local import keeps scipy a lazy dependency

    audio = np.asarray(audio_1d)
    if len(audio) < N_FFT:
        audio = np.pad(audio, (0, N_FFT - len(audio)))

    _, _, zxx = sig.stft(audio, fs=SR, window='hann', nperseg=N_FFT,
                         noverlap=N_FFT - HOP_LENGTH)
    power = np.abs(zxx) ** 2
    mel = _mel_filterbank(SR, N_FFT, N_MELS) @ power
    mel_db = 10.0 * np.log10(np.maximum(mel, 1e-10))
    return np.maximum(mel_db, np.max(mel_db) - 80.0)


def _new_figure(ncols):
    """Create a dark-background, Agg-backed figure with `ncols` side-by-side axes.

    The figure is constructed directly instead of via pyplot. That keeps it
    out of pyplot's global figure registry, which would otherwise hold every
    returned figure alive until plt.close() was called on it.
    """
    fig = Figure(figsize=(14, 4), constrained_layout=True)
    FigureCanvasAgg(fig)  # attach an Agg canvas, as pyplot does under the 'Agg' backend
    fig.patch.set_facecolor(_FIG_BG)
    axes = fig.subplots(1, ncols, squeeze=False)[0]
    return fig, list(axes)


def _draw_panel(fig, ax, mel_db, duration_s, title):
    """Render one mel spectrogram into `ax` with a dB colorbar and kHz ticks.

    Rows of `mel_db` correspond to mel-spaced filters, so the y axis is laid
    out in mel units (each row's span contains its filter's centre). Ticks
    sit at the mel position of each labelled frequency and show kHz, so the
    labels match the actual band frequencies.
    """
    mel_top = float(_hz_to_mel(SR / 2))
    im = ax.imshow(mel_db, aspect='auto', origin='lower',
                   extent=[0, duration_s, 0, mel_top],
                   cmap='magma', interpolation='bilinear')

    show_hz = min(_MAX_DISPLAY_HZ, SR / 2)
    ticks_hz = np.array([f for f in _FREQ_TICKS_HZ if f <= show_hz])
    ax.set_yticks(_hz_to_mel(ticks_hz))
    ax.set_yticklabels([f'{f / 1000:g}' for f in ticks_hz])
    ax.set_ylim(0, float(_hz_to_mel(show_hz)))

    ax.set_title(title, fontsize=12, fontweight='bold', color=_FG)
    ax.set_xlabel('Time (s)', color=_FG)
    ax.set_ylabel('Frequency (kHz)', color=_FG)
    ax.set_facecolor(_AX_BG)
    ax.tick_params(colors=_FG)
    fig.colorbar(im, ax=ax, label='dB', fraction=0.046, pad=0.04)


def plot_spectrograms(original_mono: np.ndarray,
                      residual_mono: Optional[np.ndarray] = None) -> plt.Figure:
    """Return mel-spectrogram figure (1-panel or 2-panel).

    Args:
        original_mono: 1D numpy array (mono original)
        residual_mono: 1D numpy array (Demucs residual), optional

    Returns:
        matplotlib Figure. With no residual it has a single 'Mel Spectrogram'
        panel; otherwise it has 'Original' and 'Demucs Residual' panels side
        by side. Only the first 30 s of each signal is rendered.
    """
    max_samples = _MAX_SECONDS * SR
    orig = original_mono[:max_samples]
    if residual_mono is None:
        panels = [('Mel Spectrogram', orig)]
    else:
        panels = [('Original', orig),
                  ('Demucs Residual', residual_mono[:max_samples])]

    fig, axes = _new_figure(len(panels))
    for ax, (title, clip) in zip(axes, panels):
        # Clips shorter than N_FFT are zero-padded before analysis, so the
        # panel spans at least that much time (also avoids a zero-width extent).
        duration_s = max(len(clip), N_FFT) / SR
        _draw_panel(fig, ax, _compute_mel_spectrogram(clip), duration_s, title)
    return fig