|
|
| """Reference-based audio metrics for separation/cover evaluation."""
|
| from __future__ import annotations
|
|
|
| from typing import Mapping
|
|
|
| import numpy as np
|
|
|
|
|
# Floor applied to power values so that ratios and logarithms never see zero;
# also used as the "silent reference" threshold for SI-SDR.
EPS: float = 1e-10
|
|
|
|
|
def _as_mono_float(audio: np.ndarray) -> np.ndarray:
    """Return *audio* as a flat float64 vector, down-mixing 2-D input.

    A 2-D array is averaged along axis 1, i.e. rows are treated as samples
    and columns as channels. Input of any other rank is not down-mixed; it is
    only flattened.
    """
    samples = np.asarray(audio, dtype=np.float64)
    mono = samples.mean(axis=1) if samples.ndim == 2 else samples
    return np.reshape(mono, -1)
|
|
|
|
|
def _align_pair(reference: np.ndarray, estimate: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Down-mix both signals to mono and crop them to a shared length.

    Each input goes through ``_as_mono_float``; the longer one is then cut at
    its end so that both have the same number of samples. No time-shift
    search is performed.

    Raises:
        ValueError: If either signal has no samples after conversion.
    """
    mono_ref, mono_est = (_as_mono_float(sig) for sig in (reference, estimate))
    common_len = min(mono_ref.size, mono_est.size)
    if common_len > 0:
        return mono_ref[:common_len], mono_est[:common_len]
    raise ValueError("Audio metric received empty reference or estimate.")
|
|
|
|
|
def _power(audio: np.ndarray) -> float:
    """Return the sum of squared samples of *audio*.

    The input is converted to float64 and flattened first. The result is a
    total (energy-like) sum; it is not divided by the number of samples.
    """
    flat = np.asarray(audio, dtype=np.float64).reshape(-1)
    squared = np.square(flat)
    return float(squared.sum())
|
|
|
|
|
def _db_ratio(signal_power: float, noise_power: float) -> float:
    """Return ``10 * log10(signal_power / noise_power)`` in decibels.

    Both powers are floored at ``EPS`` before dividing, which avoids a
    division by zero and the logarithm of zero when either power vanishes.
    """
    numerator, denominator = (
        max(float(value), EPS) for value in (signal_power, noise_power)
    )
    return float(np.log10(numerator / denominator) * 10.0)
|
|
|
|
|
def signal_distortion_ratio(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Scale-dependent SDR in dB: 10 log10(||s||^2 / ||s - shat||^2).

    The two signals are first reduced to mono and cropped to a common length
    by ``_align_pair``; the error signal is their sample-wise difference.

    Raises:
        ValueError: Propagated from ``_align_pair`` when an input is empty.
    """
    target, approximation = _align_pair(reference, estimate)
    error = target - approximation
    return _db_ratio(_power(target), _power(error))
|
|
|
|
|
def scale_invariant_signal_distortion_ratio(reference: np.ndarray, estimate: np.ndarray) -> float:
    """SI-SDR in dB, as used by modern source-separation literature.

    After alignment both signals have their mean removed. The estimate is
    then split into its projection onto the reference (the "target") and the
    remainder (the "residual"); the result is the target-to-residual power
    ratio in decibels.

    Raises:
        ValueError: If the zero-mean reference has power at or below ``EPS``
            (silent reference), or, via ``_align_pair``, if an input is empty.
    """
    ref_zm, est_zm = (
        sig - float(np.mean(sig)) for sig in _align_pair(reference, estimate)
    )

    ref_energy = _power(ref_zm)
    if ref_energy <= EPS:
        raise ValueError("SI-SDR reference is silent.")

    # Least-squares gain that best maps the reference onto the estimate.
    gain = float(np.dot(est_zm, ref_zm) / (ref_energy + EPS))
    projection = gain * ref_zm
    leftover = est_zm - projection
    return _db_ratio(_power(projection), _power(leftover))
|
|
|
|
|
def signal_to_noise_ratio(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Reconstruction SNR in dB; identical to ``signal_distortion_ratio``.

    Provided as an alias: the call is forwarded unchanged and the
    scale-dependent SDR value is returned.
    """
    snr_db = signal_distortion_ratio(reference=reference, estimate=estimate)
    return snr_db
|
|
|
|
|
def evaluate_reference_stems(
    references: Mapping[str, np.ndarray],
    estimates: Mapping[str, np.ndarray],
) -> dict:
    """Compute true reference-based metrics for matching stems.

    The caller must provide time-aligned reference stems. Without references,
    SI-SDR/SDR cannot be interpreted as source-separation quality.

    Every stem in *references* is scored against the entry of the same name
    in *estimates*; entries present only in *estimates* are not looked at.

    Returns:
        A dict with ``"mean_si_sdr"``, ``"mean_sdr"`` and ``"mean_snr"``
        (averages over all stems) plus ``"stems"``, which maps each stem name
        to its own ``"si_sdr"``, ``"sdr"`` and ``"snr"`` values.

    Raises:
        KeyError: If a reference stem has no estimate of the same name.
        ValueError: If *references* is empty, or propagated from a metric
            (empty audio, silent SI-SDR reference).
    """
    # Insertion order fixes both the evaluation order and the key order.
    metric_fns = {
        "si_sdr": scale_invariant_signal_distortion_ratio,
        "sdr": signal_distortion_ratio,
        "snr": signal_to_noise_ratio,
    }

    per_stem: dict[str, dict[str, float]] = {}
    for stem_name, ref_audio in references.items():
        if stem_name not in estimates:
            raise KeyError(f"Missing estimated stem for reference: {stem_name}")
        est_audio = estimates[stem_name]
        per_stem[stem_name] = {
            key: metric(ref_audio, est_audio) for key, metric in metric_fns.items()
        }

    if not per_stem:
        raise ValueError("No reference stems were provided.")

    def _mean_of(metric_key: str) -> float:
        """Average one metric across all scored stems."""
        return float(np.mean([scores[metric_key] for scores in per_stem.values()]))

    return {
        "mean_si_sdr": _mean_of("si_sdr"),
        "mean_sdr": _mean_of("sdr"),
        "mean_snr": _mean_of("snr"),
        "stems": per_stem,
    }
|
|
|