AI-RVC / tests /run_mode_matrix.py
mason369's picture
Release v1.2.1
a9536c4 verified
# -*- coding: utf-8 -*-
"""Run real AI cover processing modes and write a JSON report."""
from __future__ import annotations
import argparse
import json
import platform
import sys
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Iterable, Optional
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
CASES: tuple[dict[str, Any], ...] = (
dict(name="current_roformer_auto_source_auto", separator="roformer", vc_preprocess_mode="auto", source_constraint_mode="auto", vc_pipeline_mode="current", karaoke_separation=False, singing_repair=False),
dict(name="current_roformer_uvr_deecho_source_on", separator="roformer", vc_preprocess_mode="uvr_deecho", source_constraint_mode="on", vc_pipeline_mode="current", karaoke_separation=False, singing_repair=False),
dict(name="current_roformer_karaoke_source_off", separator="roformer", vc_preprocess_mode="auto", source_constraint_mode="off", vc_pipeline_mode="current", karaoke_separation=True, singing_repair=False),
dict(name="current_demucs_auto_source_auto", separator="demucs", vc_preprocess_mode="auto", source_constraint_mode="auto", vc_pipeline_mode="current", karaoke_separation=False, singing_repair=False),
dict(name="current_uvr5_auto_source_auto", separator="uvr5", vc_preprocess_mode="auto", source_constraint_mode="auto", vc_pipeline_mode="current", karaoke_separation=False, singing_repair=False),
dict(name="official_uvr5_one_to_one", separator="uvr5", vc_preprocess_mode="auto", source_constraint_mode="auto", vc_pipeline_mode="official", karaoke_separation=False, singing_repair=False),
dict(name="official_uvr5_singing_repair", separator="uvr5", vc_preprocess_mode="auto", source_constraint_mode="auto", vc_pipeline_mode="official", karaoke_separation=False, singing_repair=True),
)
def _require_file(path: Path, label: str) -> Path:
resolved = path.expanduser().resolve()
if not resolved.exists():
raise FileNotFoundError(f"{label} does not exist: {resolved}")
if resolved.stat().st_size <= 0:
raise RuntimeError(f"{label} is empty: {resolved}")
return resolved
def _default_index(model_path: Path) -> Optional[Path]:
direct = model_path.with_suffix(".index")
if direct.exists():
return direct
indexes = sorted(model_path.parent.glob("*.index"))
return indexes[0] if len(indexes) == 1 else None
def _device(require_cuda: bool) -> str:
import torch
if torch.cuda.is_available():
return "cuda"
if require_cuda:
raise RuntimeError("CUDA is required, but torch.cuda.is_available() is False.")
return "cpu"
def _validate_outputs(result: Dict[str, str], keys: Iterable[str]) -> Dict[str, Dict[str, Any]]:
files: Dict[str, Dict[str, Any]] = {}
for key in keys:
value = result.get(key)
if not value:
raise RuntimeError(f"Pipeline result missing key: {key}")
path = Path(value)
if not path.exists():
raise FileNotFoundError(f"Pipeline output for {key} does not exist: {path}")
size = path.stat().st_size
if size <= 0:
raise RuntimeError(f"Pipeline output for {key} is empty: {path}")
files[key] = {"path": str(path), "bytes": size}
return files
def _run_case(case: Dict[str, Any], input_audio: Path, model_path: Path, index_path: Optional[Path], output_root: Path, device: str) -> Dict[str, Any]:
from infer.cover_pipeline import CoverPipeline
started = time.time()
case_output = output_root / case["name"]
case_output.mkdir(parents=True, exist_ok=True)
pipeline = CoverPipeline(device=device)
try:
result = pipeline.process(
input_audio=str(input_audio),
model_path=str(model_path),
index_path=str(index_path) if index_path else None,
pitch_shift=0,
index_ratio=0.5,
filter_radius=3,
rms_mix_rate=0.0,
protect=0.33,
speaker_id=0,
f0_method="rmvpe",
demucs_model="htdemucs_ft",
demucs_shifts=1,
demucs_overlap=0.25,
demucs_split=True,
roformer_model="ensemble:vocal_rvc",
separator=case["separator"],
uvr5_model="HP2_all_vocals",
uvr5_agg=10,
uvr5_format="wav",
use_official=True,
hubert_layer=12,
vocals_volume=1.0,
accompaniment_volume=1.0,
reverb_amount=0.0,
backing_mix=0.0,
karaoke_separation=bool(case["karaoke_separation"]),
karaoke_model="ensemble:karaoke",
karaoke_merge_backing_into_accompaniment=True,
vc_preprocess_mode=case["vc_preprocess_mode"],
source_constraint_mode=case["source_constraint_mode"],
vc_pipeline_mode=case["vc_pipeline_mode"],
singing_repair=bool(case["singing_repair"]),
output_dir=str(case_output),
model_display_name="matrix",
)
keys = ["cover", "vocals", "converted_vocals", "accompaniment", "all_files_dir"]
if case["karaoke_separation"]:
keys.extend(["lead_vocals", "backing_vocals"])
return {
"name": case["name"],
"status": "pass",
"seconds": round(time.time() - started, 3),
"parameters": case,
"result": result,
"files": _validate_outputs(result, keys),
}
except Exception as exc:
return {
"name": case["name"],
"status": "fail",
"seconds": round(time.time() - started, 3),
"parameters": case,
"error": str(exc),
"traceback": traceback.format_exc(),
}
finally:
if pipeline.separator is not None:
pipeline.separator.unload_model()
if pipeline.karaoke_separator is not None:
pipeline.karaoke_separator.unload_model()
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument("--input", required=True)
parser.add_argument("--model-path", default=str(REPO_ROOT / "assets" / "weights" / "characters" / "rin" / "Rin.pth"))
parser.add_argument("--index-path")
parser.add_argument("--output-dir", default=str(REPO_ROOT / "outputs" / "mode_matrix"))
parser.add_argument("--require-cuda", action="store_true")
parser.add_argument("--case", action="append", choices=[case["name"] for case in CASES])
args = parser.parse_args()
input_audio = _require_file(Path(args.input), "Input audio")
model_path = _require_file(Path(args.model_path), "RVC model")
index_path = _require_file(Path(args.index_path), "RVC index") if args.index_path else _default_index(model_path)
device = _device(args.require_cuda)
stamp = time.strftime("%Y%m%d_%H%M%S")
output_root = Path(args.output_dir).expanduser().resolve() / f"{platform.system().lower()}_{stamp}"
output_root.mkdir(parents=True, exist_ok=True)
selected = set(args.case or [])
cases = [case for case in CASES if not selected or case["name"] in selected]
report = {
"platform": {"system": platform.system(), "release": platform.release(), "python": platform.python_version()},
"device": device,
"input_audio": str(input_audio),
"model_path": str(model_path),
"index_path": str(index_path) if index_path else None,
"output_root": str(output_root),
"cases": [],
}
report_path = output_root / "mode_matrix_results.json"
for case in cases:
print(f"\n=== {case['name']} ===", flush=True)
record = _run_case(case, input_audio, model_path, index_path, output_root, device)
print(f"{case['name']}: {record['status']} ({record['seconds']}s)", flush=True)
if record["status"] == "fail":
print(record["error"], flush=True)
report["cases"].append(record)
report_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
failed = [case["name"] for case in report["cases"] if case["status"] != "pass"]
report["summary"] = {"total": len(cases), "passed": len(cases) - len(failed), "failed": len(failed), "failed_cases": failed}
report_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"\nReport: {report_path}", flush=True)
print(json.dumps(report["summary"], ensure_ascii=False, indent=2), flush=True)
return 1 if failed else 0
if __name__ == "__main__":
raise SystemExit(main())