#!/usr/bin/env python # -*- coding: utf-8 -*- """Evaluate local karaoke lead/backing split candidates with reproducible metrics.""" from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Dict, Iterable, List import numpy as np import soundfile as sf REPO_ROOT = Path(__file__).resolve().parent.parent if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from infer.separator import KARAOKE_DEFAULT_MODEL, KaraokeSeparator from lib.audio_metrics import evaluate_reference_stems KARAOKE_MIN_LENGTH_COVERAGE = 0.999 def _load_mono(path: Path) -> tuple[np.ndarray, int]: audio, sr = sf.read(str(path), always_2d=True) mono = np.asarray(audio, dtype=np.float32).mean(axis=1) return mono.astype(np.float32), int(sr) def _rms(audio: np.ndarray) -> float: audio = np.asarray(audio, dtype=np.float32).reshape(-1) if audio.size == 0: return 0.0 return float(np.sqrt(np.mean(np.square(audio), dtype=np.float64) + 1e-12)) def _abs_corr(a: np.ndarray, b: np.ndarray) -> float: a = np.asarray(a, dtype=np.float32).reshape(-1) b = np.asarray(b, dtype=np.float32).reshape(-1) n = min(a.size, b.size) if n < 8: return 0.0 a = a[:n] b = b[:n] if float(np.std(a)) <= 1e-8 or float(np.std(b)) <= 1e-8: return 0.0 corr = np.corrcoef(a, b)[0, 1] if not np.isfinite(corr): return 0.0 return float(abs(corr)) def score_karaoke_stems( input_vocals_path: Path, lead_path: Path, backing_path: Path, ) -> Dict[str, float]: """Score a karaoke split without ground-truth stems. This is a proxy, not SDR: good candidates reconstruct the input when summed, keep lead/backing decorrelated, and preserve a plausible backing bed instead of collapsing everything into either stem. """ input_audio, input_sr = _load_mono(Path(input_vocals_path)) lead_audio, lead_sr = _load_mono(Path(lead_path)) backing_audio, backing_sr = _load_mono(Path(backing_path)) if lead_sr != input_sr or backing_sr != input_sr: raise ValueError("Karaoke scoring expects matching sample rates.") input_len = input_audio.size aligned_len = min(input_audio.size, lead_audio.size, backing_audio.size) if aligned_len <= 0: raise ValueError("Karaoke scoring received empty audio.") length_coverage = float(aligned_len / max(1, input_len)) input_audio = input_audio[:aligned_len] lead_audio = lead_audio[:aligned_len] backing_audio = backing_audio[:aligned_len] input_rms = _rms(input_audio) lead_rms = _rms(lead_audio) backing_rms = _rms(backing_audio) reconstruction_error = _rms(input_audio - lead_audio - backing_audio) / (input_rms + 1e-12) lead_backing_abs_corr = _abs_corr(lead_audio, backing_audio) lead_input_abs_corr = _abs_corr(lead_audio, input_audio) lead_ratio = lead_rms / (input_rms + 1e-12) backing_ratio = backing_rms / (input_rms + 1e-12) backing_target = 0.24 backing_balance_penalty = abs(np.log2(max(float(backing_ratio), 1e-4) / backing_target)) lead_body_penalty = max(0.0, 0.70 - float(lead_ratio)) + max(0.0, float(lead_ratio) - 1.15) length_penalty = max(0.0, 1.0 - float(length_coverage)) score = float( 100.0 - 46.0 * float(reconstruction_error) - 30.0 * float(lead_backing_abs_corr) - 8.0 * float(backing_balance_penalty) - 12.0 * float(lead_body_penalty) - 200.0 * float(length_penalty) + 3.0 * float(lead_input_abs_corr) ) return { "score": score, "input_rms": float(input_rms), "lead_rms": float(lead_rms), "backing_rms": float(backing_rms), "lead_ratio": float(lead_ratio), "backing_ratio": float(backing_ratio), "length_coverage": float(length_coverage), "length_penalty": float(length_penalty), "reconstruction_error": float(reconstruction_error), "lead_backing_abs_corr": float(lead_backing_abs_corr), "lead_input_abs_corr": float(lead_input_abs_corr), } def score_reference_stems( reference_lead_path: Path, reference_backing_path: Path, lead_path: Path, backing_path: Path, ) -> Dict[str, object]: """Compute true reference-based SI-SDR/SDR when reference stems exist.""" reference_lead, reference_lead_sr = _load_mono(Path(reference_lead_path)) reference_backing, reference_backing_sr = _load_mono(Path(reference_backing_path)) lead_audio, lead_sr = _load_mono(Path(lead_path)) backing_audio, backing_sr = _load_mono(Path(backing_path)) if len({reference_lead_sr, reference_backing_sr, lead_sr, backing_sr}) != 1: raise ValueError("Reference scoring expects matching sample rates.") return evaluate_reference_stems( references={"lead": reference_lead, "backing": reference_backing}, estimates={"lead": lead_audio, "backing": backing_audio}, ) def _unique(items: Iterable[str]) -> List[str]: result: List[str] = [] for item in items: if item and item not in result: result.append(item) return result def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--vocals-path", required=True, help="Separated vocals.wav to split into lead/backing.") parser.add_argument("--output-dir", required=True, help="Directory for candidate stems and reports.") parser.add_argument( "--models", nargs="*", default=None, help="Karaoke model filenames. Defaults to the current strict SOTA default.", ) parser.add_argument("--reference-lead", default=None, help="Optional ground-truth/reference lead stem.") parser.add_argument("--reference-backing", default=None, help="Optional ground-truth/reference backing stem.") parser.add_argument("--device", default="cuda") return parser.parse_args() def _resolve_existing_path(path_value: str | None, label: str) -> Path | None: if not path_value: return None path = Path(path_value) if not path.is_absolute(): path = (REPO_ROOT / path).resolve() if not path.exists(): raise FileNotFoundError(f"{label} not found: {path}") return path def _result_sort_key(item: dict) -> tuple: reference_metrics = item.get("reference_metrics") if reference_metrics: return (1, float(reference_metrics["mean_si_sdr"])) metrics = item["metrics"] return ( 0, float(metrics["length_coverage"]) >= KARAOKE_MIN_LENGTH_COVERAGE, float(metrics["score"]), ) def main() -> int: args = parse_args() vocals_path = Path(args.vocals_path) if not vocals_path.is_absolute(): vocals_path = (REPO_ROOT / vocals_path).resolve() if not vocals_path.exists(): raise FileNotFoundError(f"Vocals path not found: {vocals_path}") reference_lead_path = _resolve_existing_path(args.reference_lead, "Reference lead") reference_backing_path = _resolve_existing_path(args.reference_backing, "Reference backing") has_references = reference_lead_path is not None or reference_backing_path is not None if has_references and not (reference_lead_path and reference_backing_path): raise ValueError("--reference-lead and --reference-backing must be provided together.") output_dir = Path(args.output_dir) if not output_dir.is_absolute(): output_dir = (REPO_ROOT / output_dir).resolve() output_dir.mkdir(parents=True, exist_ok=True) models = _unique(args.models or [KARAOKE_DEFAULT_MODEL]) results = [] for model_name in models: candidate_dir = output_dir / Path(model_name).stem candidate_dir.mkdir(parents=True, exist_ok=True) separator = KaraokeSeparator(model_filename=model_name, device=args.device) try: lead_path, backing_path = separator.separate(str(vocals_path), str(candidate_dir)) metrics = score_karaoke_stems(vocals_path, Path(lead_path), Path(backing_path)) reference_metrics = None if reference_lead_path and reference_backing_path: reference_metrics = score_reference_stems( reference_lead_path=reference_lead_path, reference_backing_path=reference_backing_path, lead_path=Path(lead_path), backing_path=Path(backing_path), ) results.append( { "model": model_name, "lead_path": str(lead_path), "backing_path": str(backing_path), "metrics": metrics, "reference_metrics": reference_metrics, } ) if reference_metrics: print( f"{model_name}: mean_si_sdr={reference_metrics['mean_si_sdr']:.3f}, " f"diagnostic_score={metrics['score']:.3f}" ) else: print(f"{model_name}: score={metrics['score']:.3f}, backing_ratio={metrics['backing_ratio']:.3f}") finally: separator.unload_model() results.sort(key=_result_sort_key, reverse=True) summary = { "vocals_path": str(vocals_path), "output_dir": str(output_dir), "reference_lead_path": str(reference_lead_path) if reference_lead_path else None, "reference_backing_path": str(reference_backing_path) if reference_backing_path else None, "ranking": [ { "rank": index + 1, "model": item["model"], **item["metrics"], **( { "reference_mean_si_sdr": item["reference_metrics"]["mean_si_sdr"], "reference_mean_sdr": item["reference_metrics"]["mean_sdr"], } if item.get("reference_metrics") else {} ), } for index, item in enumerate(results) ], "results": results, "score_note": ( "SI-SDR/SDR are only computed when reference stems are provided. " "Without references, score is a diagnostic proxy for reconstruction, " "decorrelation, plausible backing level, lead/input coherence, and full-length coverage." ), } summary_path = output_dir / "karaoke_model_report.json" summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") markdown_lines = [ "# Karaoke Model Report", "", f"- vocals: `{vocals_path}`", f"- output: `{output_dir}`", "", ] if has_references: markdown_lines.extend( [ f"- reference lead: `{reference_lead_path}`", f"- reference backing: `{reference_backing_path}`", "", "| rank | model | mean_si_sdr | mean_sdr | diag_score | len | recon_err | corr |", "|---:|---|---:|---:|---:|---:|---:|---:|", ] ) else: markdown_lines.extend( [ "| rank | model | score | len | recon_err | corr | lead_in_corr | lead_ratio | backing_ratio |", "|---:|---|---:|---:|---:|---:|---:|---:|---:|", ] ) for index, item in enumerate(results, start=1): metrics = item["metrics"] reference_metrics = item.get("reference_metrics") if reference_metrics: markdown_lines.append( f"| {index} | `{item['model']}` | {reference_metrics['mean_si_sdr']:.3f} | " f"{reference_metrics['mean_sdr']:.3f} | {metrics['score']:.3f} | " f"{metrics['length_coverage']:.4f} | {metrics['reconstruction_error']:.4f} | " f"{metrics['lead_backing_abs_corr']:.4f} |" ) else: markdown_lines.append( f"| {index} | `{item['model']}` | {metrics['score']:.3f} | " f"{metrics['length_coverage']:.4f} | " f"{metrics['reconstruction_error']:.4f} | {metrics['lead_backing_abs_corr']:.4f} | " f"{metrics['lead_input_abs_corr']:.4f} | " f"{metrics['lead_ratio']:.3f} | {metrics['backing_ratio']:.3f} |" ) markdown_path = output_dir / "karaoke_model_report.md" markdown_path.write_text("\n".join(markdown_lines) + "\n", encoding="utf-8") print(f"Summary written to: {summary_path}") print(f"Report written to: {markdown_path}") return 0 if __name__ == "__main__": raise SystemExit(main())