Spaces:
Sleeping
Sleeping
feat: add full training and evaluation pipeline for AURIS
Browse files- app/training/run_full_pipeline.py +177 -0
- app/training/visualize_results.py +632 -0
app/training/run_full_pipeline.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AURIS Full Training & Evaluation Pipeline.
|
| 3 |
+
|
| 4 |
+
Orchestrates the complete ML pipeline end-to-end:
|
| 5 |
+
|
| 6 |
+
1. Feature extraction (49 audio features + 14 vocal features)
|
| 7 |
+
2. Heuristic baseline evaluation
|
| 8 |
+
3. Multi-model training with 5-fold CV
|
| 9 |
+
4. Publication-quality figure generation
|
| 10 |
+
5. Results summary
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python -m app.training.run_full_pipeline
|
| 14 |
+
|
| 15 |
+
# Or with custom paths:
|
| 16 |
+
python -m app.training.run_full_pipeline \\
|
| 17 |
+
--manifest data/training/manifest.csv \\
|
| 18 |
+
--models-dir models \\
|
| 19 |
+
--figures-dir figures
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
import json
|
| 26 |
+
import sys
|
| 27 |
+
import time
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
|
| 30 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def main() -> None:
|
| 34 |
+
parser = argparse.ArgumentParser(description="AURIS Full Pipeline")
|
| 35 |
+
parser.add_argument(
|
| 36 |
+
"--manifest", default="data/training/manifest.csv",
|
| 37 |
+
help="Path to training manifest CSV",
|
| 38 |
+
)
|
| 39 |
+
parser.add_argument(
|
| 40 |
+
"--models-dir", default="models",
|
| 41 |
+
help="Directory for saved models",
|
| 42 |
+
)
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--figures-dir", default="figures",
|
| 45 |
+
help="Directory for output figures",
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--skip-extract", action="store_true",
|
| 49 |
+
help="Skip feature extraction (use existing features.csv)",
|
| 50 |
+
)
|
| 51 |
+
args = parser.parse_args()
|
| 52 |
+
|
| 53 |
+
manifest_path = Path(args.manifest)
|
| 54 |
+
features_path = manifest_path.parent / "features.csv"
|
| 55 |
+
models_dir = Path(args.models_dir)
|
| 56 |
+
figures_dir = Path(args.figures_dir)
|
| 57 |
+
|
| 58 |
+
total_start = time.time()
|
| 59 |
+
|
| 60 |
+
# ── Step 1: Feature Extraction ─────────────────────────────
|
| 61 |
+
if not args.skip_extract:
|
| 62 |
+
print("\n" + "=" * 70)
|
| 63 |
+
print(" STEP 1 / 5 — Feature Extraction")
|
| 64 |
+
print("=" * 70)
|
| 65 |
+
|
| 66 |
+
from app.training.extract_features_batch import extract_batch
|
| 67 |
+
t0 = time.time()
|
| 68 |
+
features_path = extract_batch(manifest_path)
|
| 69 |
+
print(f"\n Extraction time: {time.time() - t0:.1f}s")
|
| 70 |
+
else:
|
| 71 |
+
print("\n [Skipping extraction — using existing features.csv]")
|
| 72 |
+
if not features_path.exists():
|
| 73 |
+
print(f" ERROR: {features_path} not found!")
|
| 74 |
+
sys.exit(1)
|
| 75 |
+
|
| 76 |
+
# ── Step 2: Heuristic Baseline ─────────────────────────────
|
| 77 |
+
print("\n" + "=" * 70)
|
| 78 |
+
print(" STEP 2 / 5 — Heuristic Baseline Evaluation")
|
| 79 |
+
print("=" * 70)
|
| 80 |
+
|
| 81 |
+
from app.training.evaluate import evaluate_heuristic_baseline
|
| 82 |
+
baseline_results = evaluate_heuristic_baseline(features_path)
|
| 83 |
+
|
| 84 |
+
# ── Step 3: Multi-Model Training ───────────────────────────
|
| 85 |
+
print("\n" + "=" * 70)
|
| 86 |
+
print(" STEP 3 / 5 — Multi-Model Training (5-Fold CV)")
|
| 87 |
+
print("=" * 70)
|
| 88 |
+
|
| 89 |
+
from app.training.train_classifier import train
|
| 90 |
+
train_results = train(features_path, models_dir)
|
| 91 |
+
|
| 92 |
+
# Save combined results (training + baseline) with arrays for visualization
|
| 93 |
+
combined_results = train_results["all_results"].copy()
|
| 94 |
+
|
| 95 |
+
# Add baseline heuristic results
|
| 96 |
+
if "heuristic_only" in baseline_results:
|
| 97 |
+
combined_results["Heuristic (no vocals)"] = baseline_results["heuristic_only"]
|
| 98 |
+
if "heuristic_vocals" in baseline_results:
|
| 99 |
+
combined_results["Heuristic + Vocals"] = baseline_results["heuristic_vocals"]
|
| 100 |
+
|
| 101 |
+
# Add metadata
|
| 102 |
+
combined_results["_best_model"] = train_results["best_model"]
|
| 103 |
+
combined_results["_n_samples"] = int(train_results["all_results"][
|
| 104 |
+
train_results["best_model"]
|
| 105 |
+
]["y_true"].__len__())
|
| 106 |
+
combined_results["_n_features"] = len(train_results["feature_cols"])
|
| 107 |
+
combined_results["_n_folds"] = 5
|
| 108 |
+
|
| 109 |
+
# Load feature importance from saved results
|
| 110 |
+
results_json_path = models_dir / "training_results.json"
|
| 111 |
+
if results_json_path.exists():
|
| 112 |
+
with open(results_json_path, "r") as f:
|
| 113 |
+
saved = json.load(f)
|
| 114 |
+
if "_feature_importance" in saved:
|
| 115 |
+
combined_results["_feature_importance"] = saved["_feature_importance"]
|
| 116 |
+
|
| 117 |
+
# Save full results with arrays for visualization
|
| 118 |
+
full_results_path = models_dir / "full_training_results.json"
|
| 119 |
+
serializable = {}
|
| 120 |
+
for key, value in combined_results.items():
|
| 121 |
+
if isinstance(value, dict):
|
| 122 |
+
serializable[key] = {
|
| 123 |
+
k: (v.tolist() if hasattr(v, "tolist") else v)
|
| 124 |
+
for k, v in value.items()
|
| 125 |
+
}
|
| 126 |
+
else:
|
| 127 |
+
serializable[key] = value
|
| 128 |
+
|
| 129 |
+
with open(full_results_path, "w") as f:
|
| 130 |
+
json.dump(serializable, f, indent=2)
|
| 131 |
+
|
| 132 |
+
# ── Step 4: Visualization ──────────────────────────────────
|
| 133 |
+
print("\n" + "=" * 70)
|
| 134 |
+
print(" STEP 4 / 5 — Publication-Quality Figures")
|
| 135 |
+
print("=" * 70)
|
| 136 |
+
|
| 137 |
+
from app.training.visualize_results import generate_all_figures
|
| 138 |
+
generate_all_figures(full_results_path, features_path, figures_dir)
|
| 139 |
+
|
| 140 |
+
# ── Step 5: Summary ────────────────────────────────────────
|
| 141 |
+
print("\n" + "=" * 70)
|
| 142 |
+
print(" STEP 5 / 5 — Final Summary")
|
| 143 |
+
print("=" * 70)
|
| 144 |
+
|
| 145 |
+
total_time = time.time() - total_start
|
| 146 |
+
|
| 147 |
+
print(f"\n Pipeline completed in {total_time:.1f}s ({total_time / 60:.1f}m)")
|
| 148 |
+
print(f"\n Best model: {train_results['best_model']}")
|
| 149 |
+
print(f" Best AUC: {train_results['best_auc']:.4f}")
|
| 150 |
+
print(f"\n Artifacts:")
|
| 151 |
+
print(f" Model: {train_results['model_path']}")
|
| 152 |
+
print(f" Results: {full_results_path}")
|
| 153 |
+
print(f" Figures: {figures_dir}/")
|
| 154 |
+
print(f" Features: {features_path}")
|
| 155 |
+
|
| 156 |
+
# Print all model results in a table
|
| 157 |
+
print(f"\n {'Model':<25} {'Acc':>7} {'Prec':>7} {'Rec':>7} {'F1':>7} {'AUC':>7}")
|
| 158 |
+
print(f" {'─' * 25} {'─' * 7} {'─' * 7} {'─' * 7} {'─' * 7} {'─' * 7}")
|
| 159 |
+
|
| 160 |
+
for name in sorted(train_results["all_results"].keys()):
|
| 161 |
+
data = train_results["all_results"][name]
|
| 162 |
+
print(
|
| 163 |
+
f" {name:<25} "
|
| 164 |
+
f"{data.get('accuracy', 0):>7.4f} "
|
| 165 |
+
f"{data.get('precision', 0):>7.4f} "
|
| 166 |
+
f"{data.get('recall', 0):>7.4f} "
|
| 167 |
+
f"{data.get('f1', 0):>7.4f} "
|
| 168 |
+
f"{data.get('roc_auc', 0):>7.4f}"
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
print("\n" + "=" * 70)
|
| 172 |
+
print(" AURIS Training Pipeline Complete!")
|
| 173 |
+
print("=" * 70)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
if __name__ == "__main__":
|
| 177 |
+
main()
|
app/training/visualize_results.py
ADDED
|
@@ -0,0 +1,632 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Publication-quality visualization pipeline for AURIS.
|
| 3 |
+
|
| 4 |
+
Generates all figures required for an academic paper / conference
|
| 5 |
+
submission on AI-generated music detection:
|
| 6 |
+
|
| 7 |
+
1. ROC curves (per-model overlay)
|
| 8 |
+
2. Precision-Recall curves (per-model overlay)
|
| 9 |
+
3. Confusion matrices (heatmap per model)
|
| 10 |
+
4. Model comparison bar chart (Accuracy, F1, AUC side-by-side)
|
| 11 |
+
5. Feature importance (top-N horizontal bar)
|
| 12 |
+
6. Feature correlation heatmap
|
| 13 |
+
7. Feature distribution violin plots (AI vs Human)
|
| 14 |
+
8. Training summary table (LaTeX-ready)
|
| 15 |
+
|
| 16 |
+
Usage:
|
| 17 |
+
python -m app.training.visualize_results \\
|
| 18 |
+
--results models/training_results.json \\
|
| 19 |
+
--features data/training/features.csv \\
|
| 20 |
+
--output figures/
|
| 21 |
+
|
| 22 |
+
All figures are saved at 300 DPI in both PNG and PDF formats
|
| 23 |
+
for direct inclusion in LaTeX / Word documents.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
import csv
|
| 30 |
+
import json
|
| 31 |
+
import sys
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
from typing import Any
|
| 34 |
+
|
| 35 |
+
import numpy as np
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
import matplotlib
|
| 39 |
+
matplotlib.use("Agg") # Non-interactive backend for server/CI
|
| 40 |
+
import matplotlib.pyplot as plt
|
| 41 |
+
import matplotlib.ticker as mticker
|
| 42 |
+
from matplotlib.colors import LinearSegmentedColormap
|
| 43 |
+
except ImportError:
|
| 44 |
+
print("ERROR: matplotlib required. pip install matplotlib")
|
| 45 |
+
sys.exit(1)
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
import seaborn as sns
|
| 49 |
+
HAS_SEABORN = True
|
| 50 |
+
except ImportError:
|
| 51 |
+
HAS_SEABORN = False
|
| 52 |
+
|
| 53 |
+
from sklearn.metrics import (
|
| 54 |
+
auc,
|
| 55 |
+
confusion_matrix,
|
| 56 |
+
precision_recall_curve,
|
| 57 |
+
roc_curve,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 61 |
+
# Style configuration — academic paper quality
|
| 62 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 63 |
+
|
| 64 |
+
# Color palette — distinct, colorblind-friendly
|
| 65 |
+
MODEL_COLORS = {
|
| 66 |
+
"Logistic Regression": "#4363d8",
|
| 67 |
+
"Random Forest": "#3cb44b",
|
| 68 |
+
"Gradient Boosting": "#e6194b",
|
| 69 |
+
"SVM (RBF)": "#f58231",
|
| 70 |
+
"MLP Neural Network": "#911eb4",
|
| 71 |
+
"XGBoost": "#42d4f4",
|
| 72 |
+
"LightGBM": "#f032e6",
|
| 73 |
+
"Heuristic (no vocals)": "#808080",
|
| 74 |
+
"Heuristic + Vocals": "#a9a9a9",
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
AURIS_BLUE = "#1a73e8"
|
| 78 |
+
AURIS_RED = "#e8431a"
|
| 79 |
+
|
| 80 |
+
plt.rcParams.update({
|
| 81 |
+
"font.family": "serif",
|
| 82 |
+
"font.size": 11,
|
| 83 |
+
"axes.titlesize": 13,
|
| 84 |
+
"axes.labelsize": 12,
|
| 85 |
+
"xtick.labelsize": 10,
|
| 86 |
+
"ytick.labelsize": 10,
|
| 87 |
+
"legend.fontsize": 9,
|
| 88 |
+
"figure.dpi": 150,
|
| 89 |
+
"savefig.dpi": 300,
|
| 90 |
+
"savefig.bbox": "tight",
|
| 91 |
+
"savefig.pad_inches": 0.1,
|
| 92 |
+
"axes.grid": True,
|
| 93 |
+
"grid.alpha": 0.3,
|
| 94 |
+
"axes.spines.top": False,
|
| 95 |
+
"axes.spines.right": False,
|
| 96 |
+
})
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _save_fig(fig: plt.Figure, output_dir: Path, name: str) -> None:
|
| 100 |
+
"""Save figure in both PNG and PDF formats."""
|
| 101 |
+
fig.savefig(output_dir / f"{name}.png", format="png")
|
| 102 |
+
fig.savefig(output_dir / f"{name}.pdf", format="pdf")
|
| 103 |
+
plt.close(fig)
|
| 104 |
+
print(f" Saved: {name}.png / .pdf")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _get_color(name: str) -> str:
|
| 108 |
+
return MODEL_COLORS.get(name, "#333333")
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 112 |
+
# Figure 1: ROC Curves
|
| 113 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 114 |
+
|
| 115 |
+
def plot_roc_curves(
|
| 116 |
+
results: dict[str, Any],
|
| 117 |
+
output_dir: Path,
|
| 118 |
+
) -> None:
|
| 119 |
+
"""Plot ROC curves for all models on the same axes."""
|
| 120 |
+
fig, ax = plt.subplots(figsize=(7, 6))
|
| 121 |
+
|
| 122 |
+
for name, data in results.items():
|
| 123 |
+
if name.startswith("_"):
|
| 124 |
+
continue
|
| 125 |
+
y_true = np.array(data["y_true"])
|
| 126 |
+
y_prob = np.array(data["y_prob"])
|
| 127 |
+
|
| 128 |
+
fpr, tpr, _ = roc_curve(y_true, y_prob)
|
| 129 |
+
roc_auc = auc(fpr, tpr)
|
| 130 |
+
|
| 131 |
+
ax.plot(
|
| 132 |
+
fpr, tpr,
|
| 133 |
+
color=_get_color(name),
|
| 134 |
+
linewidth=2,
|
| 135 |
+
label=f"{name} (AUC = {roc_auc:.3f})",
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Diagonal reference
|
| 139 |
+
ax.plot([0, 1], [0, 1], "k--", linewidth=1, alpha=0.5, label="Random (AUC = 0.500)")
|
| 140 |
+
|
| 141 |
+
ax.set_xlim([-0.02, 1.02])
|
| 142 |
+
ax.set_ylim([-0.02, 1.02])
|
| 143 |
+
ax.set_xlabel("False Positive Rate")
|
| 144 |
+
ax.set_ylabel("True Positive Rate")
|
| 145 |
+
ax.set_title("ROC Curves — AURIS Model Comparison")
|
| 146 |
+
ax.legend(loc="lower right", framealpha=0.9)
|
| 147 |
+
ax.set_aspect("equal")
|
| 148 |
+
|
| 149 |
+
_save_fig(fig, output_dir, "fig1_roc_curves")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 153 |
+
# Figure 2: Precision-Recall Curves
|
| 154 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 155 |
+
|
| 156 |
+
def plot_pr_curves(
|
| 157 |
+
results: dict[str, Any],
|
| 158 |
+
output_dir: Path,
|
| 159 |
+
) -> None:
|
| 160 |
+
"""Plot Precision-Recall curves for all models."""
|
| 161 |
+
fig, ax = plt.subplots(figsize=(7, 6))
|
| 162 |
+
|
| 163 |
+
for name, data in results.items():
|
| 164 |
+
if name.startswith("_"):
|
| 165 |
+
continue
|
| 166 |
+
y_true = np.array(data["y_true"])
|
| 167 |
+
y_prob = np.array(data["y_prob"])
|
| 168 |
+
|
| 169 |
+
precision, recall, _ = precision_recall_curve(y_true, y_prob)
|
| 170 |
+
pr_auc = auc(recall, precision)
|
| 171 |
+
|
| 172 |
+
ax.plot(
|
| 173 |
+
recall, precision,
|
| 174 |
+
color=_get_color(name),
|
| 175 |
+
linewidth=2,
|
| 176 |
+
label=f"{name} (AP = {pr_auc:.3f})",
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
# Baseline: proportion of positives
|
| 180 |
+
all_y = []
|
| 181 |
+
for name, data in results.items():
|
| 182 |
+
if not name.startswith("_"):
|
| 183 |
+
all_y = data["y_true"]
|
| 184 |
+
break
|
| 185 |
+
baseline = np.mean(all_y) if all_y else 0.5
|
| 186 |
+
ax.axhline(y=baseline, color="k", linestyle="--", linewidth=1, alpha=0.5,
|
| 187 |
+
label=f"Baseline ({baseline:.2f})")
|
| 188 |
+
|
| 189 |
+
ax.set_xlim([-0.02, 1.02])
|
| 190 |
+
ax.set_ylim([-0.02, 1.05])
|
| 191 |
+
ax.set_xlabel("Recall")
|
| 192 |
+
ax.set_ylabel("Precision")
|
| 193 |
+
ax.set_title("Precision-Recall Curves — AURIS Model Comparison")
|
| 194 |
+
ax.legend(loc="lower left", framealpha=0.9)
|
| 195 |
+
|
| 196 |
+
_save_fig(fig, output_dir, "fig2_pr_curves")
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 200 |
+
# Figure 3: Confusion Matrices
|
| 201 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 202 |
+
|
| 203 |
+
def plot_confusion_matrices(
|
| 204 |
+
results: dict[str, Any],
|
| 205 |
+
output_dir: Path,
|
| 206 |
+
) -> None:
|
| 207 |
+
"""Plot confusion matrix heatmap for each model."""
|
| 208 |
+
model_names = [k for k in results if not k.startswith("_")]
|
| 209 |
+
n_models = len(model_names)
|
| 210 |
+
cols = min(3, n_models)
|
| 211 |
+
rows = (n_models + cols - 1) // cols
|
| 212 |
+
|
| 213 |
+
fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4.5 * rows))
|
| 214 |
+
if n_models == 1:
|
| 215 |
+
axes = np.array([axes])
|
| 216 |
+
axes = axes.flatten()
|
| 217 |
+
|
| 218 |
+
cmap = LinearSegmentedColormap.from_list("auris", ["#ffffff", AURIS_BLUE])
|
| 219 |
+
|
| 220 |
+
for idx, name in enumerate(model_names):
|
| 221 |
+
ax = axes[idx]
|
| 222 |
+
data = results[name]
|
| 223 |
+
y_true = np.array(data["y_true"])
|
| 224 |
+
y_pred = np.array(data["y_pred"])
|
| 225 |
+
|
| 226 |
+
cm = confusion_matrix(y_true, y_pred)
|
| 227 |
+
cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
|
| 228 |
+
|
| 229 |
+
im = ax.imshow(cm_norm, interpolation="nearest", cmap=cmap, vmin=0, vmax=1)
|
| 230 |
+
|
| 231 |
+
# Annotate cells with count and percentage
|
| 232 |
+
for i in range(2):
|
| 233 |
+
for j in range(2):
|
| 234 |
+
color = "white" if cm_norm[i, j] > 0.6 else "black"
|
| 235 |
+
ax.text(j, i, f"{cm[i, j]}\n({cm_norm[i, j]:.1%})",
|
| 236 |
+
ha="center", va="center", fontsize=12, color=color,
|
| 237 |
+
fontweight="bold")
|
| 238 |
+
|
| 239 |
+
ax.set_xticks([0, 1])
|
| 240 |
+
ax.set_yticks([0, 1])
|
| 241 |
+
ax.set_xticklabels(["Human", "AI"])
|
| 242 |
+
ax.set_yticklabels(["Human", "AI"])
|
| 243 |
+
ax.set_xlabel("Predicted")
|
| 244 |
+
ax.set_ylabel("Actual")
|
| 245 |
+
ax.set_title(name, fontsize=11)
|
| 246 |
+
|
| 247 |
+
# Hide unused axes
|
| 248 |
+
for idx in range(n_models, len(axes)):
|
| 249 |
+
axes[idx].set_visible(False)
|
| 250 |
+
|
| 251 |
+
fig.suptitle("Confusion Matrices — AURIS Model Comparison", fontsize=14, y=1.02)
|
| 252 |
+
fig.tight_layout()
|
| 253 |
+
_save_fig(fig, output_dir, "fig3_confusion_matrices")
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 257 |
+
# Figure 4: Model Comparison Bar Chart
|
| 258 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 259 |
+
|
| 260 |
+
def plot_model_comparison(
|
| 261 |
+
results: dict[str, Any],
|
| 262 |
+
output_dir: Path,
|
| 263 |
+
) -> None:
|
| 264 |
+
"""Bar chart comparing Accuracy, F1, Precision, Recall, AUC across models."""
|
| 265 |
+
model_names = [k for k in results if not k.startswith("_")]
|
| 266 |
+
metrics = ["accuracy", "precision", "recall", "f1", "roc_auc"]
|
| 267 |
+
metric_labels = ["Accuracy", "Precision", "Recall", "F1 Score", "ROC-AUC"]
|
| 268 |
+
|
| 269 |
+
x = np.arange(len(model_names))
|
| 270 |
+
width = 0.15
|
| 271 |
+
metric_colors = ["#4363d8", "#3cb44b", "#e6194b", "#f58231", "#911eb4"]
|
| 272 |
+
|
| 273 |
+
fig, ax = plt.subplots(figsize=(max(10, len(model_names) * 2), 6))
|
| 274 |
+
|
| 275 |
+
for i, (metric, label, color) in enumerate(zip(metrics, metric_labels, metric_colors)):
|
| 276 |
+
values = []
|
| 277 |
+
for name in model_names:
|
| 278 |
+
val = results[name].get(metric, 0)
|
| 279 |
+
values.append(val if val is not None else 0)
|
| 280 |
+
bars = ax.bar(x + i * width, values, width, label=label, color=color, alpha=0.85)
|
| 281 |
+
|
| 282 |
+
# Value labels on bars
|
| 283 |
+
for bar, val in zip(bars, values):
|
| 284 |
+
ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
|
| 285 |
+
f"{val:.2f}", ha="center", va="bottom", fontsize=7)
|
| 286 |
+
|
| 287 |
+
ax.set_xlabel("Model")
|
| 288 |
+
ax.set_ylabel("Score")
|
| 289 |
+
ax.set_title("Model Performance Comparison — AURIS")
|
| 290 |
+
ax.set_xticks(x + width * 2)
|
| 291 |
+
ax.set_xticklabels(model_names, rotation=25, ha="right")
|
| 292 |
+
ax.set_ylim([0, 1.12])
|
| 293 |
+
ax.legend(loc="upper right", ncol=5, framealpha=0.9)
|
| 294 |
+
ax.yaxis.set_major_formatter(mticker.PercentFormatter(1.0))
|
| 295 |
+
|
| 296 |
+
fig.tight_layout()
|
| 297 |
+
_save_fig(fig, output_dir, "fig4_model_comparison")
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 301 |
+
# Figure 5: Feature Importance
|
| 302 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 303 |
+
|
| 304 |
+
def plot_feature_importance(
|
| 305 |
+
results: dict[str, Any],
|
| 306 |
+
output_dir: Path,
|
| 307 |
+
top_n: int = 20,
|
| 308 |
+
) -> None:
|
| 309 |
+
"""Horizontal bar chart of top-N feature importances."""
|
| 310 |
+
importance = results.get("_feature_importance")
|
| 311 |
+
if not importance:
|
| 312 |
+
print(" Skipping feature importance — no data available")
|
| 313 |
+
return
|
| 314 |
+
|
| 315 |
+
# Sort and take top N
|
| 316 |
+
sorted_features = sorted(importance.items(), key=lambda x: x[1], reverse=True)[:top_n]
|
| 317 |
+
names = [f[0] for f in reversed(sorted_features)]
|
| 318 |
+
values = [f[1] for f in reversed(sorted_features)]
|
| 319 |
+
|
| 320 |
+
fig, ax = plt.subplots(figsize=(8, max(6, top_n * 0.35)))
|
| 321 |
+
|
| 322 |
+
colors = plt.cm.Blues(np.linspace(0.3, 0.9, len(names)))
|
| 323 |
+
bars = ax.barh(names, values, color=colors, edgecolor="white", linewidth=0.5)
|
| 324 |
+
|
| 325 |
+
# Value labels
|
| 326 |
+
for bar, val in zip(bars, values):
|
| 327 |
+
ax.text(bar.get_width() + 0.002, bar.get_y() + bar.get_height() / 2,
|
| 328 |
+
f"{val:.4f}", ha="left", va="center", fontsize=9)
|
| 329 |
+
|
| 330 |
+
ax.set_xlabel("Relative Importance")
|
| 331 |
+
ax.set_title(f"Top {top_n} Feature Importances — AURIS ({results.get('_best_model', '')})")
|
| 332 |
+
ax.set_xlim([0, max(values) * 1.15])
|
| 333 |
+
|
| 334 |
+
fig.tight_layout()
|
| 335 |
+
_save_fig(fig, output_dir, "fig5_feature_importance")
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 339 |
+
# Figure 6: Feature Correlation Heatmap
|
| 340 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 341 |
+
|
| 342 |
+
def plot_correlation_heatmap(
|
| 343 |
+
features_csv: Path,
|
| 344 |
+
output_dir: Path,
|
| 345 |
+
) -> None:
|
| 346 |
+
"""Correlation heatmap of all features."""
|
| 347 |
+
X, y, feature_cols = _load_features_with_names(features_csv)
|
| 348 |
+
|
| 349 |
+
corr = np.corrcoef(X.T)
|
| 350 |
+
|
| 351 |
+
fig, ax = plt.subplots(figsize=(max(12, len(feature_cols) * 0.4),
|
| 352 |
+
max(10, len(feature_cols) * 0.35)))
|
| 353 |
+
|
| 354 |
+
cmap = "RdBu_r" if HAS_SEABORN else "coolwarm"
|
| 355 |
+
|
| 356 |
+
if HAS_SEABORN:
|
| 357 |
+
sns.heatmap(
|
| 358 |
+
corr, xticklabels=feature_cols, yticklabels=feature_cols,
|
| 359 |
+
cmap=cmap, center=0, vmin=-1, vmax=1,
|
| 360 |
+
square=True, linewidths=0.5, ax=ax,
|
| 361 |
+
cbar_kws={"shrink": 0.8, "label": "Pearson Correlation"},
|
| 362 |
+
)
|
| 363 |
+
else:
|
| 364 |
+
im = ax.imshow(corr, cmap=cmap, vmin=-1, vmax=1, aspect="auto")
|
| 365 |
+
ax.set_xticks(range(len(feature_cols)))
|
| 366 |
+
ax.set_yticks(range(len(feature_cols)))
|
| 367 |
+
ax.set_xticklabels(feature_cols, rotation=90, fontsize=7)
|
| 368 |
+
ax.set_yticklabels(feature_cols, fontsize=7)
|
| 369 |
+
fig.colorbar(im, ax=ax, shrink=0.8, label="Pearson Correlation")
|
| 370 |
+
|
| 371 |
+
ax.set_title("Feature Correlation Matrix — AURIS", fontsize=14, pad=20)
|
| 372 |
+
|
| 373 |
+
fig.tight_layout()
|
| 374 |
+
_save_fig(fig, output_dir, "fig6_correlation_heatmap")
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 378 |
+
# Figure 7: Feature Distribution (Violin / Box Plots)
|
| 379 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 380 |
+
|
| 381 |
+
def plot_feature_distributions(
|
| 382 |
+
features_csv: Path,
|
| 383 |
+
output_dir: Path,
|
| 384 |
+
results: dict[str, Any] | None = None,
|
| 385 |
+
top_n: int = 12,
|
| 386 |
+
) -> None:
|
| 387 |
+
"""Violin plots showing feature distributions for AI vs Human."""
|
| 388 |
+
X, y, feature_cols = _load_features_with_names(features_csv)
|
| 389 |
+
|
| 390 |
+
# Select top features by importance, or first N
|
| 391 |
+
if results and "_feature_importance" in results:
|
| 392 |
+
sorted_feats = sorted(
|
| 393 |
+
results["_feature_importance"].items(),
|
| 394 |
+
key=lambda x: x[1], reverse=True,
|
| 395 |
+
)
|
| 396 |
+
selected = [f[0] for f in sorted_feats[:top_n] if f[0] in feature_cols]
|
| 397 |
+
else:
|
| 398 |
+
selected = feature_cols[:top_n]
|
| 399 |
+
|
| 400 |
+
n_features = len(selected)
|
| 401 |
+
cols = 3
|
| 402 |
+
rows = (n_features + cols - 1) // cols
|
| 403 |
+
|
| 404 |
+
fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 3.5 * rows))
|
| 405 |
+
axes = axes.flatten()
|
| 406 |
+
|
| 407 |
+
for idx, feat_name in enumerate(selected):
|
| 408 |
+
ax = axes[idx]
|
| 409 |
+
col_idx = feature_cols.index(feat_name)
|
| 410 |
+
|
| 411 |
+
human_vals = X[y == 0, col_idx]
|
| 412 |
+
ai_vals = X[y == 1, col_idx]
|
| 413 |
+
|
| 414 |
+
if HAS_SEABORN:
|
| 415 |
+
data_list = []
|
| 416 |
+
for val in human_vals:
|
| 417 |
+
data_list.append({"Feature": feat_name, "Value": val, "Class": "Human"})
|
| 418 |
+
for val in ai_vals:
|
| 419 |
+
data_list.append({"Feature": feat_name, "Value": val, "Class": "AI"})
|
| 420 |
+
|
| 421 |
+
import pandas as pd
|
| 422 |
+
df = pd.DataFrame(data_list)
|
| 423 |
+
sns.violinplot(
|
| 424 |
+
data=df, x="Class", y="Value",
|
| 425 |
+
palette={"Human": AURIS_BLUE, "AI": AURIS_RED},
|
| 426 |
+
ax=ax, inner="quartile", linewidth=1,
|
| 427 |
+
)
|
| 428 |
+
else:
|
| 429 |
+
parts = ax.violinplot(
|
| 430 |
+
[human_vals, ai_vals],
|
| 431 |
+
positions=[0, 1],
|
| 432 |
+
showmeans=True, showmedians=True,
|
| 433 |
+
)
|
| 434 |
+
ax.set_xticks([0, 1])
|
| 435 |
+
ax.set_xticklabels(["Human", "AI"])
|
| 436 |
+
|
| 437 |
+
ax.set_title(feat_name, fontsize=10)
|
| 438 |
+
ax.set_xlabel("")
|
| 439 |
+
|
| 440 |
+
for idx in range(n_features, len(axes)):
|
| 441 |
+
axes[idx].set_visible(False)
|
| 442 |
+
|
| 443 |
+
fig.suptitle("Feature Distributions — AI vs Human", fontsize=14, y=1.01)
|
| 444 |
+
fig.tight_layout()
|
| 445 |
+
_save_fig(fig, output_dir, "fig7_feature_distributions")
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 449 |
+
# Figure 8: Training Summary Table (LaTeX)
|
| 450 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 451 |
+
|
| 452 |
+
def generate_latex_table(
|
| 453 |
+
results: dict[str, Any],
|
| 454 |
+
output_dir: Path,
|
| 455 |
+
) -> None:
|
| 456 |
+
"""Generate LaTeX-ready comparison table."""
|
| 457 |
+
model_names = [k for k in results if not k.startswith("_")]
|
| 458 |
+
best_model = results.get("_best_model", "")
|
| 459 |
+
|
| 460 |
+
lines = [
|
| 461 |
+
r"\begin{table}[htbp]",
|
| 462 |
+
r"\centering",
|
| 463 |
+
r"\caption{AURIS Model Performance Comparison}",
|
| 464 |
+
r"\label{tab:model-comparison}",
|
| 465 |
+
r"\begin{tabular}{lccccc}",
|
| 466 |
+
r"\toprule",
|
| 467 |
+
r"Model & Accuracy & Precision & Recall & F1 & ROC-AUC \\",
|
| 468 |
+
r"\midrule",
|
| 469 |
+
]
|
| 470 |
+
|
| 471 |
+
for name in model_names:
|
| 472 |
+
data = results[name]
|
| 473 |
+
acc = data.get("accuracy", 0)
|
| 474 |
+
prec = data.get("precision", 0)
|
| 475 |
+
rec = data.get("recall", 0)
|
| 476 |
+
f1 = data.get("f1", 0)
|
| 477 |
+
roc = data.get("roc_auc", 0)
|
| 478 |
+
|
| 479 |
+
# Bold the best model
|
| 480 |
+
prefix = r"\textbf{" if name == best_model else ""
|
| 481 |
+
suffix = "}" if name == best_model else ""
|
| 482 |
+
|
| 483 |
+
row = (
|
| 484 |
+
f" {prefix}{name}{suffix} & "
|
| 485 |
+
f"{prefix}{acc:.4f}{suffix} & "
|
| 486 |
+
f"{prefix}{prec:.4f}{suffix} & "
|
| 487 |
+
f"{prefix}{rec:.4f}{suffix} & "
|
| 488 |
+
f"{prefix}{f1:.4f}{suffix} & "
|
| 489 |
+
f"{prefix}{roc:.4f}{suffix} \\\\"
|
| 490 |
+
)
|
| 491 |
+
lines.append(row)
|
| 492 |
+
|
| 493 |
+
lines.extend([
|
| 494 |
+
r"\bottomrule",
|
| 495 |
+
r"\end{tabular}",
|
| 496 |
+
r"\end{table}",
|
| 497 |
+
])
|
| 498 |
+
|
| 499 |
+
latex_content = "\n".join(lines)
|
| 500 |
+
table_path = output_dir / "table1_model_comparison.tex"
|
| 501 |
+
table_path.write_text(latex_content, encoding="utf-8")
|
| 502 |
+
print(f" Saved: table1_model_comparison.tex")
|
| 503 |
+
|
| 504 |
+
# Also save as markdown for README
|
| 505 |
+
md_lines = [
|
| 506 |
+
"| Model | Accuracy | Precision | Recall | F1 | ROC-AUC |",
|
| 507 |
+
"|-------|----------|-----------|--------|-----|---------|",
|
| 508 |
+
]
|
| 509 |
+
for name in model_names:
|
| 510 |
+
data = results[name]
|
| 511 |
+
bold = "**" if name == best_model else ""
|
| 512 |
+
md_lines.append(
|
| 513 |
+
f"| {bold}{name}{bold} | "
|
| 514 |
+
f"{data.get('accuracy', 0):.4f} | "
|
| 515 |
+
f"{data.get('precision', 0):.4f} | "
|
| 516 |
+
f"{data.get('recall', 0):.4f} | "
|
| 517 |
+
f"{data.get('f1', 0):.4f} | "
|
| 518 |
+
f"{data.get('roc_auc', 0):.4f} |"
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
md_path = output_dir / "table1_model_comparison.md"
|
| 522 |
+
md_path.write_text("\n".join(md_lines), encoding="utf-8")
|
| 523 |
+
print(f" Saved: table1_model_comparison.md")
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
# ═════════════════════════════════════════════���═════════════════════════
|
| 527 |
+
# Utilities
|
| 528 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 529 |
+
|
| 530 |
+
def _load_features_with_names(
|
| 531 |
+
features_csv: Path,
|
| 532 |
+
) -> tuple[np.ndarray, np.ndarray, list[str]]:
|
| 533 |
+
"""Load features CSV returning X, y, and column names."""
|
| 534 |
+
rows = []
|
| 535 |
+
labels = []
|
| 536 |
+
|
| 537 |
+
with open(features_csv, "r", encoding="utf-8") as f:
|
| 538 |
+
reader = csv.DictReader(f)
|
| 539 |
+
feature_cols = [
|
| 540 |
+
c for c in reader.fieldnames
|
| 541 |
+
if c not in ("file_path", "label_int")
|
| 542 |
+
]
|
| 543 |
+
for row in reader:
|
| 544 |
+
feat_values = []
|
| 545 |
+
for col in feature_cols:
|
| 546 |
+
try:
|
| 547 |
+
feat_values.append(float(row[col]))
|
| 548 |
+
except (ValueError, KeyError):
|
| 549 |
+
feat_values.append(0.0)
|
| 550 |
+
rows.append(feat_values)
|
| 551 |
+
labels.append(int(row["label_int"]))
|
| 552 |
+
|
| 553 |
+
X = np.nan_to_num(np.array(rows, dtype=np.float32), nan=0.0)
|
| 554 |
+
y = np.array(labels, dtype=np.int32)
|
| 555 |
+
|
| 556 |
+
return X, y, feature_cols
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 560 |
+
# Main entry point
|
| 561 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 562 |
+
|
| 563 |
+
def generate_all_figures(
|
| 564 |
+
results_path: str | Path,
|
| 565 |
+
features_csv: str | Path,
|
| 566 |
+
output_dir: str | Path = "figures",
|
| 567 |
+
) -> None:
|
| 568 |
+
"""Generate all publication-quality figures."""
|
| 569 |
+
results_path = Path(results_path)
|
| 570 |
+
features_csv = Path(features_csv)
|
| 571 |
+
output_dir = Path(output_dir)
|
| 572 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 573 |
+
|
| 574 |
+
print(f"\nLoading results from {results_path}...")
|
| 575 |
+
with open(results_path, "r") as f:
|
| 576 |
+
saved_results = json.load(f)
|
| 577 |
+
|
| 578 |
+
# We need y_true/y_pred/y_prob arrays — check if they're stored
|
| 579 |
+
# If not in the saved JSON, we can't plot ROC/PR curves
|
| 580 |
+
has_arrays = any(
|
| 581 |
+
"y_true" in v for k, v in saved_results.items()
|
| 582 |
+
if isinstance(v, dict) and not k.startswith("_")
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
print(f"\nGenerating figures in {output_dir}/...\n")
|
| 586 |
+
|
| 587 |
+
if has_arrays:
|
| 588 |
+
print("[1/8] ROC Curves")
|
| 589 |
+
plot_roc_curves(saved_results, output_dir)
|
| 590 |
+
|
| 591 |
+
print("[2/8] Precision-Recall Curves")
|
| 592 |
+
plot_pr_curves(saved_results, output_dir)
|
| 593 |
+
|
| 594 |
+
print("[3/8] Confusion Matrices")
|
| 595 |
+
plot_confusion_matrices(saved_results, output_dir)
|
| 596 |
+
else:
|
| 597 |
+
print("[1-3/8] Skipping ROC/PR/CM — no per-sample predictions stored")
|
| 598 |
+
|
| 599 |
+
print("[4/8] Model Comparison Bar Chart")
|
| 600 |
+
plot_model_comparison(saved_results, output_dir)
|
| 601 |
+
|
| 602 |
+
print("[5/8] Feature Importance")
|
| 603 |
+
plot_feature_importance(saved_results, output_dir)
|
| 604 |
+
|
| 605 |
+
print("[6/8] Feature Correlation Heatmap")
|
| 606 |
+
plot_correlation_heatmap(features_csv, output_dir)
|
| 607 |
+
|
| 608 |
+
print("[7/8] Feature Distributions (AI vs Human)")
|
| 609 |
+
plot_feature_distributions(features_csv, output_dir, saved_results)
|
| 610 |
+
|
| 611 |
+
print("[8/8] LaTeX / Markdown Tables")
|
| 612 |
+
generate_latex_table(saved_results, output_dir)
|
| 613 |
+
|
| 614 |
+
print(f"\nAll figures saved to {output_dir}/")
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
if __name__ == "__main__":
|
| 618 |
+
parser = argparse.ArgumentParser(description="AURIS Visualization Pipeline")
|
| 619 |
+
parser.add_argument(
|
| 620 |
+
"--results", default="models/training_results.json",
|
| 621 |
+
help="Path to training_results.json",
|
| 622 |
+
)
|
| 623 |
+
parser.add_argument(
|
| 624 |
+
"--features", default="data/training/features.csv",
|
| 625 |
+
help="Path to features.csv",
|
| 626 |
+
)
|
| 627 |
+
parser.add_argument(
|
| 628 |
+
"--output", default="figures",
|
| 629 |
+
help="Output directory for figures",
|
| 630 |
+
)
|
| 631 |
+
args = parser.parse_args()
|
| 632 |
+
generate_all_figures(args.results, args.features, args.output)
|