# AI-RVC / lib/audio_metrics.py
# Author: mason369
# Release v1.2.1 (commit a9536c4, verified)
# -*- coding: utf-8 -*-
"""Reference-based audio metrics for separation/cover evaluation."""
from __future__ import annotations
from typing import Mapping
import numpy as np
# Floor applied to signal and noise powers in _db_ratio, so zero powers
# (silent signals, perfect reconstructions) give finite dB values instead of
# dividing by zero / taking log10(0). Also used as the "reference is silent"
# threshold in the SI-SDR check.
EPS = 1e-10
def _as_mono_float(audio: np.ndarray) -> np.ndarray:
    """Return *audio* as a flat 1-D float64 array.

    A 2-D input is downmixed by averaging along axis 1. Any other rank is
    flattened as-is.

    NOTE(review): averaging along axis 1 assumes a (samples, channels)
    layout. A channel-first (channels, samples) array would instead collapse
    to one value per channel. Confirm the callers' layout.
    """
    samples = np.asarray(audio, dtype=np.float64)
    if samples.ndim == 2:
        # Downmix: average across the second axis (assumed to be channels).
        samples = samples.mean(axis=1)
    return samples.ravel()
def _align_pair(reference: np.ndarray, estimate: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Downmix both signals to mono float64 and truncate them to a common length.

    Only the tail of the longer signal is dropped; no time shift or
    cross-correlation alignment is done, so callers must already supply
    time-aligned audio.

    Raises:
        ValueError: if either signal is empty after downmixing.
    """
    ref_mono = _as_mono_float(reference)
    est_mono = _as_mono_float(estimate)
    common_len = min(ref_mono.size, est_mono.size)
    if not common_len:
        raise ValueError("Audio metric received empty reference or estimate.")
    return ref_mono[:common_len], est_mono[:common_len]
def _power(audio: np.ndarray) -> float:
    """Return the total energy (sum of squared samples) of *audio* as a float.

    Input of any shape is flattened first; an empty input has power 0.0.
    """
    flat = np.asarray(audio, dtype=np.float64).ravel()
    return float(np.sum(np.square(flat)))
def _db_ratio(signal_power: float, noise_power: float) -> float:
    """Return ``10 * log10(signal_power / noise_power)`` in decibels.

    Both powers are coerced to ``float`` and floored at ``EPS``, so a zero
    (or negative) power never causes division by zero or ``log10(0)``.
    """
    floored_signal = max(float(signal_power), EPS)
    floored_noise = max(float(noise_power), EPS)
    power_ratio = floored_signal / floored_noise
    return float(10.0 * np.log10(power_ratio))
def signal_distortion_ratio(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Scale-dependent SDR in dB: ``10 log10(||s||^2 / ||s - s_hat||^2)``.

    Both inputs are downmixed and truncated to their common length first.
    Powers are floored at ``EPS`` by ``_db_ratio``, so a perfect estimate
    gives a large finite value rather than +inf.

    Raises:
        ValueError: if either input is empty.
    """
    ref_mono, est_mono = _align_pair(reference, estimate)
    error_power = _power(ref_mono - est_mono)
    reference_power = _power(ref_mono)
    return _db_ratio(reference_power, error_power)
def scale_invariant_signal_distortion_ratio(reference: np.ndarray, estimate: np.ndarray) -> float:
    """SI-SDR as used by modern source-separation literature.

    Both inputs are downmixed, truncated to a common length and made
    zero-mean. The estimate is split into ``target = alpha * reference``
    (its orthogonal projection onto the reference) and
    ``residual = estimate - target``. The result is
    ``10 log10(||target||^2 / ||residual||^2)`` in dB.

    The estimate is normalized to unit energy before the split. Target and
    residual powers then sum to 1, and the ``EPS`` floor applied by
    ``_db_ratio`` acts on relative powers. That makes the score truly
    independent of the estimate's gain and bounds it to about +/-100 dB
    (``10 log10(1 / EPS)``). Without the normalization, a very quiet estimate
    had both powers clamped to ``EPS`` and always scored ~0 dB.

    A silent (constant) estimate yields 0.0 dB.

    Raises:
        ValueError: if either input is empty, or if the zero-mean reference
            has power <= ``EPS``.
    """
    ref, est = _align_pair(reference, estimate)
    # Zero-mean both signals so a DC offset does not affect the score.
    ref = ref - float(np.mean(ref))
    est = est - float(np.mean(est))
    ref_power = _power(ref)
    if ref_power <= EPS:
        raise ValueError("SI-SDR reference is silent.")
    est_power = _power(est)
    if est_power > 0.0:
        # Unit-energy estimate: keeps the EPS floors relative, not absolute.
        est = est / np.sqrt(est_power)
    # ref_power > EPS is guaranteed above. Adding EPS here (as before) would
    # bias the projection scale for quiet references.
    scale = float(np.dot(est, ref) / ref_power)
    target = scale * ref
    residual = est - target
    return _db_ratio(_power(target), _power(residual))
def signal_to_noise_ratio(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Reconstruction SNR in dB.

    This is the same quantity as ``signal_distortion_ratio``, exposed under
    its own name for reporting. Both arguments are forwarded unchanged.
    """
    snr_db = signal_distortion_ratio(reference=reference, estimate=estimate)
    return snr_db
def evaluate_reference_stems(
    references: Mapping[str, np.ndarray],
    estimates: Mapping[str, np.ndarray],
) -> dict:
    """Score each estimated stem against its reference and average the scores.

    For every stem name in *references*, the matching entry in *estimates* is
    scored with SI-SDR, SDR and SNR (all in dB). Estimate stems with no
    reference are ignored. The caller must supply time-aligned reference
    stems; without true references, SI-SDR/SDR are not meaningful
    separation-quality numbers.

    Returns:
        A dict with ``"mean_si_sdr"``, ``"mean_sdr"`` and ``"mean_snr"``
        (arithmetic means of the per-stem dB values) and ``"stems"``, which
        maps each stem name to its ``{"si_sdr", "sdr", "snr"}`` scores.

    Raises:
        KeyError: if a reference stem has no matching estimate.
        ValueError: if *references* is empty, or if a metric rejects a stem
            (e.g. empty audio or a silent reference for SI-SDR).
    """
    per_stem: dict[str, dict[str, float]] = {}
    for name, ref_audio in references.items():
        # Membership test (not EAFP) so mappings with default values still
        # report a missing stem.
        if name not in estimates:
            raise KeyError(f"Missing estimated stem for reference: {name}")
        est_audio = estimates[name]
        per_stem[name] = {
            "si_sdr": scale_invariant_signal_distortion_ratio(ref_audio, est_audio),
            "sdr": signal_distortion_ratio(ref_audio, est_audio),
            "snr": signal_to_noise_ratio(ref_audio, est_audio),
        }

    if not per_stem:
        raise ValueError("No reference stems were provided.")

    def _mean_of(metric_key: str) -> float:
        # Arithmetic mean of one metric across all scored stems.
        return float(np.mean([scores[metric_key] for scores in per_stem.values()]))

    return {
        "mean_si_sdr": _mean_of("si_sdr"),
        "mean_sdr": _mean_of("sdr"),
        "mean_snr": _mean_of("snr"),
        "stems": per_stem,
    }