Spaces:
Runtime error
Runtime error
| # Created: 2026-02-18 | |
| # Purpose: Original/residual mel-spectrogram visualization (matplotlib) | |
| # Dependencies: matplotlib, numpy, scipy, torch | |
| """Mel-spectrogram comparison visualization of original audio and analysis results.""" | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| from config import SR, N_FFT, HOP_LENGTH | |
# Number of triangular mel bands in the filterbank built by
# _compute_mel_spectrogram (i.e. the number of rows in its output).
N_MELS = 128
def _compute_mel_spectrogram(audio_1d: np.ndarray) -> np.ndarray:
    """Convert 1D audio into a mel power spectrogram on a dB scale.

    The signal is analysed with a Hann-window STFT (``N_FFT`` samples per
    segment, hop of ``HOP_LENGTH`` samples). The power spectrum is projected
    onto ``N_MELS`` triangular filters whose edges are evenly spaced on the
    mel scale between 0 Hz and ``SR / 2``, and the result is converted to
    decibels.

    Args:
        audio_1d: 1D array of audio samples. Inputs shorter than ``N_FFT``
            samples (including empty input) are zero-padded at the end to
            one full analysis window.

    Returns:
        Array of shape ``(N_MELS, n_frames)`` in dB, floored at 80 dB below
        its own maximum.
    """
    from scipy import signal as sig

    audio_1d = np.asarray(audio_1d)

    # scipy shrinks nperseg when the input is shorter than one window, which
    # either trips its "noverlap must be less than nperseg" check or yields
    # fewer frequency bins than the filterbank below expects. Padding to a
    # full window keeps the STFT shape consistent with N_FFT.
    shortfall = N_FFT - audio_1d.shape[-1]
    if shortfall > 0:
        pad_width = [(0, 0)] * (audio_1d.ndim - 1) + [(0, shortfall)]
        audio_1d = np.pad(audio_1d, pad_width)

    # STFT
    _, _, Zxx = sig.stft(audio_1d, fs=SR, window='hann',
                         nperseg=N_FFT, noverlap=N_FFT - HOP_LENGTH)
    power = np.abs(Zxx) ** 2

    # Mel filterbank
    n_freqs = N_FFT // 2 + 1

    def hz_to_mel(f):
        return 2595.0 * np.log10(1.0 + f / 700.0)

    def mel_to_hz(m):
        return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

    # N_MELS + 2 edge frequencies: filter i rises from edge i to edge i + 1
    # and falls back to zero at edge i + 2.
    mel_pts = np.linspace(hz_to_mel(0), hz_to_mel(SR / 2), N_MELS + 2)
    hz_pts = mel_to_hz(mel_pts)
    lo, mid, hi = hz_pts[:-2], hz_pts[1:-1], hz_pts[2:]

    # Column vector of STFT bin frequencies; broadcasting against the edge
    # vectors builds the whole (n_freqs, N_MELS) filterbank in one step.
    freqs = np.linspace(0, SR / 2, n_freqs)[:, np.newaxis]

    # The edges are strictly increasing for SR > 0, so both slopes are
    # well-defined. Each triangle is the lower of its rising and falling
    # ramps, clipped at zero outside [lo, hi].
    rising = (freqs - lo) / (mid - lo)
    falling = (hi - freqs) / (hi - mid)
    fb = np.clip(np.minimum(rising, falling), 0.0, None).astype(np.float32)

    mel = fb.T @ power

    # Clamp before the log so silent bins give -100 dB instead of -inf, then
    # limit the dynamic range to 80 dB below the loudest bin.
    mel_db = 10.0 * np.log10(np.maximum(mel, 1e-10))
    max_val = np.max(mel_db)
    mel_db = np.maximum(mel_db, max_val - 80.0)
    return mel_db
# Only the first 30 seconds of each signal are plotted.
_MAX_PLOT_SECONDS = 30
# Upper display limit of the frequency axis (capped at Nyquist when lower).
_MAX_PLOT_KHZ = 16.0
# Candidate y-axis tick frequencies; those above the display limit are dropped.
_MEL_TICKS_KHZ = (0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0)


def _mel_row_position(freq_hz, n_mels):
    """Map a frequency in Hz to the y coordinate of a mel image.

    Row ``i`` of the spectrogram belongs to the triangular filter centred on
    point ``i + 1`` of ``n_mels + 2`` points evenly spaced in mel between
    0 Hz and ``SR / 2`` (the layout used by ``_compute_mel_spectrogram``).
    With the image drawn over ``[0, n_mels]``, row ``i`` is centred on
    ``i + 0.5``, hence the ``- 0.5`` offset.
    """
    def to_mel(f):
        return 2595.0 * np.log10(1.0 + f / 700.0)

    return to_mel(freq_hz) / to_mel(SR / 2) * (n_mels + 1) - 0.5


def _draw_mel_panel(fig, ax, mel_db, duration_s, title):
    """Draw one dark-themed mel-spectrogram panel with its colorbar."""
    n_mels = mel_db.shape[0]
    image = ax.imshow(mel_db, aspect='auto', origin='lower',
                      extent=[0, duration_s, 0, n_mels],
                      cmap='magma', interpolation='bilinear')
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (kHz)')

    # The rows are mel bands, not linear frequency bins, so the kHz labels
    # are placed at the mel-scale positions of round frequencies.
    top_khz = min(_MAX_PLOT_KHZ, SR / 2000)
    tick_khz = [khz for khz in _MEL_TICKS_KHZ if khz <= top_khz]
    ax.set_yticks([_mel_row_position(khz * 1000.0, n_mels)
                   for khz in tick_khz])
    ax.set_yticklabels([f'{khz:g}' for khz in tick_khz])
    ax.set_ylim(0, min(_mel_row_position(top_khz * 1000.0, n_mels), n_mels))

    # Attach the colorbar to this figure explicitly instead of relying on
    # pyplot's "current figure", and keep its text readable on the dark
    # background.
    cbar = fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('dB', color='white')
    cbar.ax.tick_params(colors='white')

    ax.set_facecolor('#16213e')
    ax.tick_params(colors='white')
    ax.xaxis.label.set_color('white')
    ax.yaxis.label.set_color('white')
    ax.title.set_color('white')


def plot_spectrograms(original_mono: np.ndarray,
                      residual_mono: np.ndarray = None) -> plt.Figure:
    """Return a mel-spectrogram figure (1-panel or 2-panel).

    Each signal is truncated to its first 30 seconds before analysis. The
    y axis is labelled in kHz at mel-scale positions and shows frequencies
    up to 16 kHz, or up to ``SR / 2`` when that is lower.

    Args:
        original_mono: 1D numpy array (mono original).
        residual_mono: 1D numpy array (Demucs residual), optional. When
            given, it is drawn in a second panel next to the original.

    Returns:
        matplotlib Figure. It is created through pyplot, so the caller
        should close it (``plt.close(fig)``) once it is no longer needed.
    """
    max_samples = _MAX_PLOT_SECONDS * SR

    if residual_mono is None:
        clips = [('Mel Spectrogram', original_mono[:max_samples])]
    else:
        clips = [('Original', original_mono[:max_samples]),
                 ('Demucs Residual', residual_mono[:max_samples])]

    # Run the analysis before creating the figure so that a failure here
    # does not leave an orphaned figure registered with pyplot.
    panels = [(title, _compute_mel_spectrogram(clip), len(clip) / SR)
              for title, clip in clips]

    fig, axes = plt.subplots(1, len(panels), figsize=(14, 4),
                             constrained_layout=True, squeeze=False)
    fig.patch.set_facecolor('#1a1a2e')
    for ax, (title, mel_db, duration_s) in zip(axes[0], panels):
        _draw_mel_panel(fig, ax, mel_db, duration_s, title)
    return fig