Spaces:
Runtime error
Runtime error
File size: 4,703 Bytes
742e266 0020ddc 742e266 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 | # Created: 2026-02-18
# Purpose: Original/residual mel-spectrogram visualization (matplotlib)
# Dependencies: matplotlib, numpy, scipy, torch
"""Mel-spectrogram comparison visualization of original audio and analysis results."""
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from config import SR, N_FFT, HOP_LENGTH
# Number of mel bands (spectrogram rows) produced by _compute_mel_spectrogram.
N_MELS: int = 128
def _compute_mel_spectrogram(audio_1d: np.ndarray) -> np.ndarray:
"""1D audio -> mel spectrogram (dB scale)."""
from scipy import signal as sig
# STFT
_, _, Zxx = sig.stft(audio_1d, fs=SR, window='hann',
nperseg=N_FFT, noverlap=N_FFT - HOP_LENGTH)
mag = np.abs(Zxx)
# Mel filterbank
n_freqs = N_FFT // 2 + 1
def hz_to_mel(f): return 2595.0 * np.log10(1.0 + f / 700.0)
def mel_to_hz(m): return 700.0 * (10.0 ** (m / 2595.0) - 1.0)
mel_pts = np.linspace(hz_to_mel(0), hz_to_mel(SR / 2), N_MELS + 2)
hz_pts = mel_to_hz(mel_pts)
freqs = np.linspace(0, SR / 2, n_freqs)
fb = np.zeros((n_freqs, N_MELS), dtype=np.float32)
for i in range(N_MELS):
lo, mid, hi = hz_pts[i], hz_pts[i + 1], hz_pts[i + 2]
for j in range(n_freqs):
if lo <= freqs[j] <= mid and (mid - lo) > 0:
fb[j, i] = (freqs[j] - lo) / (mid - lo)
elif mid < freqs[j] <= hi and (hi - mid) > 0:
fb[j, i] = (hi - freqs[j]) / (hi - mid)
mel = fb.T @ (mag ** 2)
mel_db = 10.0 * np.log10(np.maximum(mel, 1e-10))
max_val = np.max(mel_db)
mel_db = np.maximum(mel_db, max_val - 80.0)
return mel_db
# Highest frequency shown on the y-axis (Hz); further capped at Nyquist (SR / 2).
_MAX_DISPLAY_HZ = 16000.0
# Candidate y-axis ticks in kHz; only those inside the displayed range are used.
_TICK_KHZ = (0.25, 0.5, 1, 2, 4, 8, 16)


def _hz_to_mel_htk(freq_hz):
    """Convert Hz to mel with the HTK formula used by _compute_mel_spectrogram."""
    return 2595.0 * np.log10(1.0 + np.asarray(freq_hz, dtype=np.float64) / 700.0)


def _draw_mel_panel(fig, ax, mel_db, duration_s, title):
    """Draw one dark-themed mel panel (image, kHz ticks, colorbar) on *ax*."""
    n_bands = mel_db.shape[0]
    nyquist_hz = SR / 2.0
    # Band i is centred at mel (i + 1) * mel_step (centres evenly spaced in mel
    # from 0 Hz to Nyquist). Laying the image out in mel units keeps every row
    # at its true frequency; a linear 0..Nyquist kHz extent mislabels them.
    mel_step = float(_hz_to_mel_htk(nyquist_hz)) / (n_bands + 1)
    y_bottom = 0.5 * mel_step
    y_top = (n_bands + 0.5) * mel_step
    im = ax.imshow(mel_db, aspect='auto', origin='lower',
                   extent=[0, duration_s, y_bottom, y_top],
                   cmap='magma', interpolation='bilinear')

    shown_hz = min(_MAX_DISPLAY_HZ, nyquist_hz)
    y_limit = min(y_top, float(_hz_to_mel_htk(shown_hz)))
    ticks = [(float(_hz_to_mel_htk(k * 1000.0)), k) for k in _TICK_KHZ]
    ticks = [(pos, k) for pos, k in ticks if y_bottom <= pos <= y_limit]
    ax.set_yticks([pos for pos, _ in ticks])
    ax.set_yticklabels([f'{k:g}' for _, k in ticks])
    # set_yticks may widen the view to fit ticks, so apply the limits last.
    ax.set_ylim(y_bottom, y_limit)

    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (kHz)')
    cbar = fig.colorbar(im, ax=ax, label='dB', fraction=0.046, pad=0.04)

    # Dark theme; colorbar text is recoloured too so it stays legible.
    ax.set_facecolor('#16213e')
    for target in (ax, cbar.ax):
        target.tick_params(colors='white')
        target.xaxis.label.set_color('white')
        target.yaxis.label.set_color('white')
    ax.title.set_color('white')


def plot_spectrograms(original_mono: np.ndarray,
                      residual_mono: np.ndarray = None) -> plt.Figure:
    """Return mel-spectrogram figure (1-panel or 2-panel).

    Only the first 30 seconds of each signal are analysed. The y-axis is laid
    out in mel units with tick labels in kHz, and the view is capped at
    16 kHz or Nyquist, whichever is lower.

    Args:
        original_mono: 1D numpy array (mono original).
        residual_mono: 1D numpy array (Demucs residual), optional. When given,
            "Original" and "Demucs Residual" are drawn side by side;
            otherwise a single "Mel Spectrogram" panel is drawn.

    Returns:
        matplotlib Figure, created through pyplot (it stays registered with
        pyplot until closed).
    """
    max_samples = 30 * SR
    if residual_mono is None:
        clips = [('Mel Spectrogram', original_mono[:max_samples])]
    else:
        clips = [('Original', original_mono[:max_samples]),
                 ('Demucs Residual', residual_mono[:max_samples])]

    # Compute every spectrogram before creating the figure so a failure here
    # does not leave an orphaned pyplot figure behind.
    panels = [(title, _compute_mel_spectrogram(clip), max(len(clip), 1) / SR)
              for title, clip in clips]

    fig, axes = plt.subplots(1, len(panels), figsize=(14, 4),
                             constrained_layout=True, squeeze=False)
    fig.patch.set_facecolor('#1a1a2e')
    for ax, (title, mel_db, duration_s) in zip(axes[0], panels):
        _draw_mel_panel(fig, ax, mel_db, duration_s, title)
    return fig
|