# File: artifactnet/visualization/spectrogram.py
# Commit 0020ddc: feat(space): CPU ONNX runtime build (v9.4, full-song sliding aggregation)
# Created: 2026-02-18
# Purpose: Original/residual mel-spectrogram visualization (matplotlib)
# Dependencies: matplotlib, numpy, scipy
"""Mel-spectrogram comparison visualization of original audio and analysis results."""
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from config import SR, N_FFT, HOP_LENGTH
# Number of triangular mel bands in the filterbank, i.e. the number of rows
# in every mel spectrogram produced by this module.
N_MELS = 128
def _mel_filterbank(n_freqs: int, n_mels: int, sr: float) -> np.ndarray:
    """Build a triangular mel filterbank.

    Args:
        n_freqs: number of linear STFT bins spanning 0 .. sr / 2.
        n_mels: number of mel bands.
        sr: sample rate in Hz.

    Returns:
        float32 array of shape (n_freqs, n_mels); column i is a triangle that
        rises from 0 to 1 between mel points i and i + 1 and falls back to 0
        at mel point i + 2.
    """
    def hz_to_mel(f): return 2595.0 * np.log10(1.0 + f / 700.0)
    def mel_to_hz(m): return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

    # n_mels + 2 points evenly spaced in mel: one lower edge, n_mels centres,
    # one upper edge.
    mel_pts = np.linspace(hz_to_mel(0), hz_to_mel(sr / 2), n_mels + 2)
    hz_pts = mel_to_hz(mel_pts)
    freqs = np.linspace(0, sr / 2, n_freqs)

    lo, mid, hi = hz_pts[:-2], hz_pts[1:-1], hz_pts[2:]
    # Broadcast (n_freqs, 1) against (n_mels,): the smaller of the two ramps
    # is the triangle; anything outside [lo, hi] is negative and clipped to 0.
    rising = (freqs[:, None] - lo) / (mid - lo)
    falling = (hi - freqs[:, None]) / (hi - mid)
    return np.clip(np.minimum(rising, falling), 0.0, None).astype(np.float32)


def _compute_mel_spectrogram(audio_1d: np.ndarray) -> np.ndarray:
    """1D audio -> mel spectrogram (dB scale).

    Args:
        audio_1d: 1D array of mono samples at sample rate SR. Empty input or
            input shorter than N_FFT samples is zero-padded to N_FFT.

    Returns:
        Array of shape (N_MELS, n_frames) holding mel-band power in dB,
        floored at 1e-10 power and clamped to at most 80 dB below the maximum.
    """
    from scipy import signal as sig

    # Guarantee at least one full analysis window so the STFT always yields
    # N_FFT // 2 + 1 bins, matching the filterbank built below.
    audio = np.asarray(audio_1d)
    if audio.ndim == 1 and audio.size < N_FFT:
        audio = np.pad(audio, (0, N_FFT - audio.size))

    # STFT
    _, _, Zxx = sig.stft(audio, fs=SR, window='hann',
                         nperseg=N_FFT, noverlap=N_FFT - HOP_LENGTH)
    power = np.abs(Zxx) ** 2

    # Mel filterbank applied to the power spectrum
    fb = _mel_filterbank(N_FFT // 2 + 1, N_MELS, SR)
    mel = fb.T @ power

    # Power -> dB with a floor, then limit the dynamic range to 80 dB.
    mel_db = 10.0 * np.log10(np.maximum(mel, 1e-10))
    return np.maximum(mel_db, np.max(mel_db) - 80.0)
def _draw_mel_panel(ax, mel_db: np.ndarray, duration_s: float, title: str) -> None:
    """Draw one mel spectrogram on *ax* with a kHz-labelled mel axis and dark theme.

    Args:
        ax: matplotlib Axes to draw on.
        mel_db: array of shape (n_bands, n_frames) as returned by
            _compute_mel_spectrogram.
        duration_s: duration of the plotted audio in seconds (x-axis extent).
        title: panel title.
    """
    # Must match the mel scale used by _compute_mel_spectrogram.
    def hz_to_mel(f): return 2595.0 * np.log10(1.0 + f / 700.0)

    # Row i is the band centred on point i + 1 of n_bands + 2 points spaced
    # evenly in mel between 0 Hz and Nyquist. The image is therefore linear in
    # mel, not in Hz, so it is placed in mel coordinates and labelled with
    # kHz ticks at their mel positions.
    n_bands = mel_db.shape[0]
    mel_step = hz_to_mel(SR / 2) / (n_bands + 1)
    y_bottom = 0.5 * mel_step
    y_top = (n_bands + 0.5) * mel_step

    image = ax.imshow(mel_db, aspect='auto', origin='lower',
                      extent=[0, duration_s, y_bottom, y_top],
                      cmap='magma', interpolation='bilinear')

    # Show up to 16 kHz, or up to the top of the image when Nyquist is lower.
    view_top = min(hz_to_mel(16000.0), y_top)
    tick_khz = [k for k in (0, 0.5, 1, 2, 4, 8, 16)
                if hz_to_mel(k * 1000.0) <= view_top]
    ax.set_yticks([hz_to_mel(k * 1000.0) for k in tick_khz])
    ax.set_yticklabels([f'{k:g}' for k in tick_khz])
    # Limits are set after the ticks so the tick call cannot widen the view.
    ax.set_ylim(0, view_top)

    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (kHz)')
    cbar = plt.colorbar(image, ax=ax, label='dB', fraction=0.046, pad=0.04)

    # Dark theme: white text on the axes and on the colorbar.
    ax.set_facecolor('#16213e')
    ax.tick_params(colors='white')
    ax.xaxis.label.set_color('white')
    ax.yaxis.label.set_color('white')
    ax.title.set_color('white')
    cbar.ax.tick_params(colors='white')
    cbar.ax.yaxis.label.set_color('white')


def plot_spectrograms(original_mono: np.ndarray,
                      residual_mono: np.ndarray = None) -> plt.Figure:
    """Return mel-spectrogram figure (1-panel or 2-panel).

    Only the first 30 seconds of each signal are plotted.

    Args:
        original_mono: 1D numpy array (mono original)
        residual_mono: 1D numpy array (Demucs residual), optional. When given,
            the figure has two panels ('Original' and 'Demucs Residual');
            otherwise a single 'Mel Spectrogram' panel.

    Returns:
        matplotlib Figure
    """
    max_samples = 30 * SR
    original = original_mono[:max_samples]

    if residual_mono is not None:
        panels = [('Original', original),
                  ('Demucs Residual', residual_mono[:max_samples])]
    else:
        panels = [('Mel Spectrogram', original)]

    # squeeze=False always returns a 2D array of Axes, so the 1-panel and
    # 2-panel cases share one code path.
    fig, axes = plt.subplots(1, len(panels), figsize=(14, 4),
                             constrained_layout=True, squeeze=False)
    fig.patch.set_facecolor('#1a1a2e')

    for ax, (title, audio) in zip(axes[0], panels):
        _draw_mel_panel(ax, _compute_mel_spectrogram(audio),
                        len(audio) / SR, title)
    return fig