cccode / eval_audio.py
WayneW's picture
Upload folder using huggingface_hub
705a8fd verified
raw
history blame
6 kB
# eval_audio.py
from typing import Optional
import os
import re
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import librosa
import matplotlib.pyplot as plt
# Small positive constant used by drms_avg_db_stereo: added under the square
# root and again before log10, so the log never receives an exact zero.
_EPS = 1e-12
def build_mel_transform(
    sample_rate,
    n_fft=1024,
    win_length=1024,
    hop_length=256,
    n_mels=80,
    power=1.0,
    f_min=0.0,
    f_max=None,
    mel_scale="htk",
    norm=None,
    device=None,
):
    """Create a ``torchaudio.transforms.MelSpectrogram`` from the given settings.

    Every argument except ``device`` is forwarded to the transform under the
    same keyword name; ``center`` is always set to ``True``.

    Args:
        sample_rate: Sample rate handed to the transform.
        n_fft, win_length, hop_length: STFT settings handed to the transform.
        n_mels: Number of mel bins.
        power: Exponent setting of the transform (default 1.0).
        f_min, f_max: Frequency limits handed to the transform.
        mel_scale: Mel scale name (default ``"htk"``).
        norm: Filterbank normalisation setting (default ``None``).
        device: When not ``None``, the transform is moved there with ``.to``.

    Returns:
        The constructed transform.
    """
    settings = dict(
        sample_rate=sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        power=power,
        center=True,
        norm=norm,
        mel_scale=mel_scale,
    )
    transform = torchaudio.transforms.MelSpectrogram(**settings)
    return transform if device is None else transform.to(device)
def _ensure_stereo_torch(x):
    """Return ``x`` as a tensor whose first dimension has at most two entries.

    * A 1-D tensor first gains a leading dimension of size 1.
    * A single leading entry is duplicated with ``repeat(2, 1)``.
    * More than two leading entries are cut down to the first two.
    * Anything else (exactly two entries) is returned unchanged.
    """
    wav = x.unsqueeze(0) if x.dim() == 1 else x
    n_channels = wav.size(0)
    if n_channels == 1:
        return wav.repeat(2, 1)
    if n_channels > 2:
        return wav[:2]
    return wav
@torch.no_grad()
def mel_cosine_stereo(
    ref, hat, sample_rate,
    n_fft=1024,
    win_length=1024,
    hop_length=256,
    n_mels=80,
    power=1.0,
    mel_tf=None,
):
    """Mean per-channel cosine similarity between mel spectrograms of two signals.

    Both inputs are passed through ``_ensure_stereo_torch``, transformed with
    the same mel transform, flattened per channel and compared with
    ``F.cosine_similarity``; the per-channel values are averaged.

    Args:
        ref: Reference waveform tensor (presumably ``(samples,)`` or
            ``(channels, samples)`` -- confirm against callers).
        hat: Estimated waveform tensor, same layout as ``ref``. It is moved
            to ``ref``'s device; its length may differ from ``ref``'s.
        sample_rate: Sample rate used when a transform has to be built.
        n_fft, win_length, hop_length, n_mels, power: Settings forwarded to
            ``build_mel_transform``; ignored when ``mel_tf`` is supplied.
        mel_tf: Optional ready-made transform. It is moved to ``ref``'s
            device with ``.to`` (for an ``nn.Module`` this moves the
            caller's object in place).

    Returns:
        The similarity as a Python ``float``.
    """
    ref = _ensure_stereo_torch(ref)
    hat = _ensure_stereo_torch(hat)
    device = ref.device
    # The transform lives on ref's device, so hat has to be there as well.
    hat = hat.to(device)
    if mel_tf is None:
        mel_tf = build_mel_transform(
            sample_rate=sample_rate,
            n_fft=n_fft, win_length=win_length, hop_length=hop_length,
            n_mels=n_mels, power=power, device=device
        )
    else:
        mel_tf = mel_tf.to(device)
    mel_ref = mel_tf(ref)
    mel_hat = mel_tf(hat)
    # Signals of different length give different frame counts; keep only the
    # frames both have (same policy as drms_avg_db_stereo) so the flattened
    # rows line up instead of raising a shape error.
    n_frames = min(mel_ref.size(-1), mel_hat.size(-1))
    mel_ref = mel_ref[..., :n_frames]
    mel_hat = mel_hat[..., :n_frames]
    flat_ref = mel_ref.reshape(mel_ref.size(0), -1)
    flat_hat = mel_hat.reshape(mel_hat.size(0), -1)
    sim = F.cosine_similarity(flat_ref, flat_hat, dim=-1)
    return float(sim.mean().item())
@torch.no_grad()
def drms_avg_db_stereo(ref, hat, win_length=1024, hop_length=256):
    """Average difference (hat minus ref) of framewise RMS levels in dB.

    Each signal is passed through ``_ensure_stereo_torch``, cut into windows
    of ``win_length`` samples every ``hop_length`` samples, and turned into a
    per-window level ``20 * log10(rms)``. The two level tracks are truncated
    to their common number of windows, subtracted, averaged over windows and
    then over channels.

    Args:
        ref: Reference waveform tensor.
        hat: Estimated waveform tensor.
        win_length: Window size in samples; shorter signals are zero-padded
            on the right up to this length.
        hop_length: Step between consecutive windows in samples.

    Returns:
        The mean level difference as a Python ``float``.
    """
    def _frame_levels_db(wave):
        n_samples = wave.size(1)
        if n_samples < win_length:
            # Guarantee at least one full window.
            wave = F.pad(wave, (0, win_length - n_samples))
        windows = wave.unfold(dimension=-1, size=win_length, step=hop_length)
        mean_square = windows.pow(2).mean(dim=-1)
        return 20.0 * torch.log10(torch.sqrt(mean_square + _EPS) + _EPS)

    ref = _ensure_stereo_torch(ref)
    hat = _ensure_stereo_torch(hat)
    levels_ref = _frame_levels_db(ref)
    levels_hat = _frame_levels_db(hat)
    n_common = min(levels_ref.size(-1), levels_hat.size(-1))
    delta_db = levels_hat[:, :n_common] - levels_ref[:, :n_common]
    return float(delta_db.mean(dim=-1).mean().item())
def load_stereo_wav_np(path):
    """Load an audio file with librosa and return ``(samples, sample_rate)``.

    The file is read with ``sr=None`` and ``mono=False``. A 1-D result is
    stacked with itself into two rows; a 2-D result whose first dimension is
    not 2 is sliced to its first two rows (a single-row 2-D array therefore
    stays single-row).
    """
    samples, sr = librosa.load(path, sr=None, mono=False)
    if samples.ndim == 1:
        return np.stack([samples, samples], axis=0), sr
    if samples.shape[0] != 2:
        samples = samples[:2]
    return samples, sr
def compute_spectrogram_np(audio_stereo,
                           n_fft=512,
                           hop_length=160,
                           win_length=400,
                           pool=4):
    """Build a pooled log-magnitude spectrogram for up to two channels.

    For each channel the STFT magnitude is cropped to a multiple of ``pool``
    along both axes, averaged over ``pool x pool`` tiles and passed through
    ``log1p``. When ``audio_stereo`` has fewer than two rows, the first
    channel's result is copied for the second.

    Args:
        audio_stereo: Array indexed as ``audio_stereo[channel]``.
        n_fft, hop_length, win_length: Forwarded to ``librosa.stft``.
        pool: Tile edge length used for average pooling.

    Returns:
        Array of shape ``(rows, cols, 2)`` with the two channels stacked on
        the last axis.

    Raises:
        ValueError: If the STFT is smaller than one pooling tile.
    """
    def _pooled_magnitude(channel):
        st = np.abs(
            librosa.stft(channel, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        )
        n_rows, n_cols = st.shape
        tiles_r = n_rows // pool
        tiles_c = n_cols // pool
        if tiles_r == 0 or tiles_c == 0:
            raise ValueError(f"audio too short for pooling (stft shape {st.shape})")
        cropped = st[:tiles_r * pool, :tiles_c * pool]
        # Split each axis into (tile index, offset inside tile) and average
        # over the two offset axes.
        return cropped.reshape(tiles_r, pool, tiles_c, pool).mean(axis=(1, 3))

    left = np.log1p(_pooled_magnitude(audio_stereo[0]))
    if audio_stereo.shape[0] < 2:
        right = left.copy()
    else:
        right = np.log1p(_pooled_magnitude(audio_stereo[1]))
    return np.stack([left, right], axis=-1)
def render_ref_hat_panel(title, spec_ref, spec_hat, out_path, cmap="magma"):
    """Save a 2x2 image grid comparing reference and estimate spectrograms.

    Columns are ``ref`` / ``hat``; rows are index 0 / index 1 of the last
    axis (labelled "Left" / "Right"). Both images in a row share the same
    colour limits, taken from the joint min/max of that row's two arrays.

    Args:
        title: Figure title.
        spec_ref: Array indexed as ``spec_ref[:, :, channel]``.
        spec_hat: Array with the same indexing as ``spec_ref``.
        out_path: Destination image path; its parent directory is created
            if missing.
        cmap: Matplotlib colormap name.

    Returns:
        ``False`` (after printing ``[SKIP]``) when any of the four channel
        arrays is empty, otherwise ``True`` once the image has been written.
    """
    left_ref, right_ref = spec_ref[:, :, 0], spec_ref[:, :, 1]
    left_hat, right_hat = spec_hat[:, :, 0], spec_hat[:, :, 1]
    if any(a.size == 0 for a in (left_ref, left_hat, right_ref, right_hat)):
        print("[SKIP]")
        return False

    # (ref image, hat image, vmin, vmax) for each row of the grid.
    rows = (
        (left_ref, left_hat,
         min(left_ref.min(), left_hat.min()),
         max(left_ref.max(), left_hat.max())),
        (right_ref, right_hat,
         min(right_ref.min(), right_hat.min()),
         max(right_ref.max(), right_hat.max())),
    )

    fig, axes = plt.subplots(2, 2, figsize=(8, 6), constrained_layout=True)
    try:
        for row_axes, (ref_img, hat_img, lo, hi) in zip(axes, rows):
            for ax, img in zip(row_axes, (ref_img, hat_img)):
                ax.imshow(img, origin="lower", aspect="auto", cmap=cmap, vmin=lo, vmax=hi)
        axes[0, 0].set_title("ref")
        axes[0, 1].set_title("hat")
        axes[0, 0].set_ylabel("Left")
        axes[1, 0].set_ylabel("Right")
        for ax in axes.ravel():
            ax.set_xticks([])
            ax.set_yticks([])
        fig.suptitle(title)
        os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(out_path, dpi=180)
    finally:
        # Always release the figure, even if directory creation or saving
        # fails; otherwise it stays registered in pyplot and leaks.
        plt.close(fig)
    return True
def save_ref_hat_spectrogram_panel(
    ref, hat, out_path,
    n_fft=512,
    hop_length=160,
    win_length=400,
    pool=4,
    title="ref vs hat (binaural spectrogram)",
    cmap="magma",
):
    """Compute spectrograms for ``ref`` and ``hat`` and save the comparison panel.

    Each input is converted to a NumPy array (tensors are detached, cast to
    float32 and moved to CPU), coerced to two channels, run through
    ``compute_spectrogram_np`` and the pair is handed to
    ``render_ref_hat_panel``.

    Args:
        ref: Reference signal, a ``torch.Tensor`` or an object exposing
            ``ndim`` / ``shape`` like a NumPy array.
        hat: Estimated signal, same kinds as ``ref``.
        out_path: Destination image path.
        n_fft, hop_length, win_length, pool: Forwarded to
            ``compute_spectrogram_np``.
        title: Figure title.
        cmap: Colormap name forwarded to ``render_ref_hat_panel``.

    Returns:
        Whatever ``render_ref_hat_panel`` returns.
    """
    def _as_two_channel_array(signal):
        if isinstance(signal, torch.Tensor):
            signal = signal.detach().to(torch.float32).cpu().numpy()
        if signal.ndim == 1:
            return np.stack([signal, signal], axis=0)
        n_channels = signal.shape[0]
        if n_channels == 1:
            return np.repeat(signal, 2, axis=0)
        if n_channels > 2:
            return signal[:2]
        return signal

    spec_kwargs = dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length, pool=pool)
    ref_np = _as_two_channel_array(ref)
    hat_np = _as_two_channel_array(hat)
    spec_ref = compute_spectrogram_np(ref_np, **spec_kwargs)
    spec_hat = compute_spectrogram_np(hat_np, **spec_kwargs)
    return render_ref_hat_panel(title, spec_ref, spec_hat, out_path, cmap=cmap)