Rthur2003 committed on
Commit
36a8a94
·
1 Parent(s): 5446f0d

feat: add full training and evaluation pipeline for AURIS

Browse files
app/training/run_full_pipeline.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AURIS Full Training & Evaluation Pipeline.
3
+
4
+ Orchestrates the complete ML pipeline end-to-end:
5
+
6
+ 1. Feature extraction (49 audio features + 14 vocal features)
7
+ 2. Heuristic baseline evaluation
8
+ 3. Multi-model training with 5-fold CV
9
+ 4. Publication-quality figure generation
10
+ 5. Results summary
11
+
12
+ Usage:
13
+ python -m app.training.run_full_pipeline
14
+
15
+ # Or with custom paths:
16
+ python -m app.training.run_full_pipeline \\
17
+ --manifest data/training/manifest.csv \\
18
+ --models-dir models \\
19
+ --figures-dir figures
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import argparse
25
+ import json
26
+ import sys
27
+ import time
28
+ from pathlib import Path
29
+
30
+ sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
31
+
32
+
33
def _to_serializable(value):
    """Recursively convert *value* into objects the ``json`` module can encode.

    Dicts, lists and tuples are walked; anything exposing ``tolist`` (NumPy
    arrays and NumPy scalars do) is converted through it. Other values are
    returned unchanged.
    """
    if isinstance(value, dict):
        return {key: _to_serializable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_serializable(item) for item in value]
    if hasattr(value, "tolist"):
        return value.tolist()
    return value


def _print_banner(title: str) -> None:
    """Print a pipeline step title framed by two 70-character rules."""
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)


def main() -> None:
    """Run the AURIS pipeline end-to-end from the command line.

    Steps: feature extraction (optional), heuristic baseline evaluation,
    multi-model training, figure generation and a printed summary. The
    combined results are written to ``<models-dir>/full_training_results.json``.

    Exits with status 1 when ``--skip-extract`` is given but the expected
    ``features.csv`` next to the manifest does not exist.
    """
    parser = argparse.ArgumentParser(description="AURIS Full Pipeline")
    parser.add_argument(
        "--manifest", default="data/training/manifest.csv",
        help="Path to training manifest CSV",
    )
    parser.add_argument(
        "--models-dir", default="models",
        help="Directory for saved models",
    )
    parser.add_argument(
        "--figures-dir", default="figures",
        help="Directory for output figures",
    )
    parser.add_argument(
        "--skip-extract", action="store_true",
        help="Skip feature extraction (use existing features.csv)",
    )
    args = parser.parse_args()

    manifest_path = Path(args.manifest)
    features_path = manifest_path.parent / "features.csv"
    models_dir = Path(args.models_dir)
    figures_dir = Path(args.figures_dir)

    total_start = time.time()

    # ── Step 1: Feature Extraction ─────────────────────────────
    if not args.skip_extract:
        _print_banner(" STEP 1 / 5 — Feature Extraction")

        # Project imports are deferred so each step only loads what it needs.
        from app.training.extract_features_batch import extract_batch
        t0 = time.time()
        features_path = extract_batch(manifest_path)
        print(f"\n Extraction time: {time.time() - t0:.1f}s")
    else:
        print("\n [Skipping extraction — using existing features.csv]")
        if not features_path.exists():
            print(f" ERROR: {features_path} not found!")
            sys.exit(1)

    # ── Step 2: Heuristic Baseline ─────────────────────────────
    _print_banner(" STEP 2 / 5 — Heuristic Baseline Evaluation")

    from app.training.evaluate import evaluate_heuristic_baseline
    baseline_results = evaluate_heuristic_baseline(features_path)

    # ── Step 3: Multi-Model Training ───────────────────────────
    _print_banner(" STEP 3 / 5 — Multi-Model Training (5-Fold CV)")

    from app.training.train_classifier import train
    train_results = train(features_path, models_dir)

    all_results = train_results["all_results"]
    best_model = train_results["best_model"]

    # Combined results (training + baseline) with arrays for visualization.
    # Shallow copy: per-model dicts are shared with ``all_results``.
    combined_results = all_results.copy()

    # Add baseline heuristic results under their display names
    if "heuristic_only" in baseline_results:
        combined_results["Heuristic (no vocals)"] = baseline_results["heuristic_only"]
    if "heuristic_vocals" in baseline_results:
        combined_results["Heuristic + Vocals"] = baseline_results["heuristic_vocals"]

    # Metadata keys are prefixed with "_" so plotting code can skip them
    combined_results["_best_model"] = best_model
    combined_results["_n_samples"] = len(all_results[best_model]["y_true"])
    combined_results["_n_features"] = len(train_results["feature_cols"])
    combined_results["_n_folds"] = 5

    # Load feature importance from the results file in models_dir, if present
    results_json_path = models_dir / "training_results.json"
    if results_json_path.exists():
        with open(results_json_path, "r", encoding="utf-8") as f:
            saved = json.load(f)
        if "_feature_importance" in saved:
            combined_results["_feature_importance"] = saved["_feature_importance"]

    # Save full results with arrays for visualization
    models_dir.mkdir(parents=True, exist_ok=True)
    full_results_path = models_dir / "full_training_results.json"
    with open(full_results_path, "w", encoding="utf-8") as f:
        json.dump(_to_serializable(combined_results), f, indent=2)

    # ── Step 4: Visualization ──────────────────────────────────
    _print_banner(" STEP 4 / 5 — Publication-Quality Figures")

    from app.training.visualize_results import generate_all_figures
    generate_all_figures(full_results_path, features_path, figures_dir)

    # ── Step 5: Summary ────────────────────────────────────────
    _print_banner(" STEP 5 / 5 — Final Summary")

    total_time = time.time() - total_start

    print(f"\n Pipeline completed in {total_time:.1f}s ({total_time / 60:.1f}m)")
    print(f"\n Best model: {best_model}")
    print(f" Best AUC: {train_results['best_auc']:.4f}")
    print("\n Artifacts:")
    print(f" Model: {train_results['model_path']}")
    print(f" Results: {full_results_path}")
    print(f" Figures: {figures_dir}/")
    print(f" Features: {features_path}")

    # Print all model results in a table
    print(f"\n {'Model':<25} {'Acc':>7} {'Prec':>7} {'Rec':>7} {'F1':>7} {'AUC':>7}")
    print(f" {'─' * 25} {'─' * 7} {'─' * 7} {'─' * 7} {'─' * 7} {'─' * 7}")

    for name in sorted(all_results):
        data = all_results[name]
        print(
            f" {name:<25} "
            f"{data.get('accuracy', 0):>7.4f} "
            f"{data.get('precision', 0):>7.4f} "
            f"{data.get('recall', 0):>7.4f} "
            f"{data.get('f1', 0):>7.4f} "
            f"{data.get('roc_auc', 0):>7.4f}"
        )

    _print_banner(" AURIS Training Pipeline Complete!")
175
+
176
+ if __name__ == "__main__":
177
+ main()
app/training/visualize_results.py ADDED
@@ -0,0 +1,632 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Publication-quality visualization pipeline for AURIS.
3
+
4
+ Generates all figures required for an academic paper / conference
5
+ submission on AI-generated music detection:
6
+
7
+ 1. ROC curves (per-model overlay)
8
+ 2. Precision-Recall curves (per-model overlay)
9
+ 3. Confusion matrices (heatmap per model)
10
+ 4. Model comparison bar chart (Accuracy, F1, AUC side-by-side)
11
+ 5. Feature importance (top-N horizontal bar)
12
+ 6. Feature correlation heatmap
13
+ 7. Feature distribution violin plots (AI vs Human)
14
+ 8. Training summary table (LaTeX-ready)
15
+
16
+ Usage:
17
+ python -m app.training.visualize_results \\
18
+ --results models/training_results.json \\
19
+ --features data/training/features.csv \\
20
+ --output figures/
21
+
22
+ All figures are saved at 300 DPI in both PNG and PDF formats
23
+ for direct inclusion in LaTeX / Word documents.
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import argparse
29
+ import csv
30
+ import json
31
+ import sys
32
+ from pathlib import Path
33
+ from typing import Any
34
+
35
+ import numpy as np
36
+
37
+ try:
38
+ import matplotlib
39
+ matplotlib.use("Agg") # Non-interactive backend for server/CI
40
+ import matplotlib.pyplot as plt
41
+ import matplotlib.ticker as mticker
42
+ from matplotlib.colors import LinearSegmentedColormap
43
+ except ImportError:
44
+ print("ERROR: matplotlib required. pip install matplotlib")
45
+ sys.exit(1)
46
+
47
+ try:
48
+ import seaborn as sns
49
+ HAS_SEABORN = True
50
+ except ImportError:
51
+ HAS_SEABORN = False
52
+
53
+ from sklearn.metrics import (
54
+ auc,
55
+ confusion_matrix,
56
+ precision_recall_curve,
57
+ roc_curve,
58
+ )
59
+
60
+ # ═══════════════════════════════════════════════════════════════════════
61
+ # Style configuration — academic paper quality
62
+ # ═══════════════════════════════════════════════════════════════════════
63
+
64
+ # Color palette — distinct, colorblind-friendly
65
+ MODEL_COLORS = {
66
+ "Logistic Regression": "#4363d8",
67
+ "Random Forest": "#3cb44b",
68
+ "Gradient Boosting": "#e6194b",
69
+ "SVM (RBF)": "#f58231",
70
+ "MLP Neural Network": "#911eb4",
71
+ "XGBoost": "#42d4f4",
72
+ "LightGBM": "#f032e6",
73
+ "Heuristic (no vocals)": "#808080",
74
+ "Heuristic + Vocals": "#a9a9a9",
75
+ }
76
+
77
+ AURIS_BLUE = "#1a73e8"
78
+ AURIS_RED = "#e8431a"
79
+
80
+ plt.rcParams.update({
81
+ "font.family": "serif",
82
+ "font.size": 11,
83
+ "axes.titlesize": 13,
84
+ "axes.labelsize": 12,
85
+ "xtick.labelsize": 10,
86
+ "ytick.labelsize": 10,
87
+ "legend.fontsize": 9,
88
+ "figure.dpi": 150,
89
+ "savefig.dpi": 300,
90
+ "savefig.bbox": "tight",
91
+ "savefig.pad_inches": 0.1,
92
+ "axes.grid": True,
93
+ "grid.alpha": 0.3,
94
+ "axes.spines.top": False,
95
+ "axes.spines.right": False,
96
+ })
97
+
98
+
99
def _save_fig(fig: plt.Figure, output_dir: Path, name: str) -> None:
    """Save *fig* as ``<name>.png`` and ``<name>.pdf`` in *output_dir*, then close it.

    The directory is created if missing so the individual ``plot_*`` helpers
    also work when called directly rather than through
    ``generate_all_figures``.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for ext in ("png", "pdf"):
        fig.savefig(output_dir / f"{name}.{ext}", format=ext)
    # Close explicitly so figures do not accumulate in pyplot's registry.
    plt.close(fig)
    print(f" Saved: {name}.png / .pdf")
105
+
106
+
107
def _get_color(name: str) -> str:
    """Return the palette colour for model *name*, or dark grey if unlisted."""
    try:
        return MODEL_COLORS[name]
    except KeyError:
        return "#333333"
109
+
110
+
111
+ # ═══════════════════════════════════════════════════════════════════════
112
+ # Figure 1: ROC Curves
113
+ # ═══════════════════════════════════════════════════════════════════════
114
+
115
def plot_roc_curves(
    results: dict[str, Any],
    output_dir: Path,
) -> None:
    """Plot ROC curves for all models on the same axes.

    Args:
        results: Mapping of model name to a result dict. Keys starting with
            ``"_"`` are metadata and are ignored. A model is plotted only if
            its dict holds both ``"y_true"`` and ``"y_prob"``.
        output_dir: Directory that receives ``fig1_roc_curves.png/.pdf``.
    """
    fig, ax = plt.subplots(figsize=(7, 6))

    for name, data in results.items():
        if name.startswith("_") or not isinstance(data, dict):
            continue
        # Some entries may carry summary metrics only; skip them instead of
        # failing the whole figure with a KeyError.
        if "y_true" not in data or "y_prob" not in data:
            print(f" Skipping {name} — no per-sample predictions stored")
            continue

        y_true = np.asarray(data["y_true"])
        y_prob = np.asarray(data["y_prob"])

        fpr, tpr, _ = roc_curve(y_true, y_prob)
        roc_auc = auc(fpr, tpr)

        ax.plot(
            fpr, tpr,
            color=_get_color(name),
            linewidth=2,
            label=f"{name} (AUC = {roc_auc:.3f})",
        )

    # Diagonal reference
    ax.plot([0, 1], [0, 1], "k--", linewidth=1, alpha=0.5, label="Random (AUC = 0.500)")

    ax.set_xlim([-0.02, 1.02])
    ax.set_ylim([-0.02, 1.02])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curves — AURIS Model Comparison")
    ax.legend(loc="lower right", framealpha=0.9)
    ax.set_aspect("equal")

    _save_fig(fig, output_dir, "fig1_roc_curves")
150
+
151
+
152
+ # ═══════════════════════════════════════════════════════════════════════
153
+ # Figure 2: Precision-Recall Curves
154
+ # ═══════════════════════════════════════════════════════════════════════
155
+
156
def plot_pr_curves(
    results: dict[str, Any],
    output_dir: Path,
) -> None:
    """Plot Precision-Recall curves for all models.

    The legend reports average precision (AP) computed with
    ``sklearn.metrics.average_precision_score``, so the number matches its
    "AP" label rather than a trapezoidal area under the curve.

    Args:
        results: Mapping of model name to a result dict. Keys starting with
            ``"_"`` are metadata and are ignored. A model is plotted only if
            its dict holds both ``"y_true"`` and ``"y_prob"``.
        output_dir: Directory that receives ``fig2_pr_curves.png/.pdf``.
    """
    # Local import keeps this block self-contained; sklearn.metrics is
    # already a dependency of this module.
    from sklearn.metrics import average_precision_score

    fig, ax = plt.subplots(figsize=(7, 6))

    baseline_labels = None
    for name, data in results.items():
        if name.startswith("_") or not isinstance(data, dict):
            continue
        if "y_true" not in data or "y_prob" not in data:
            print(f" Skipping {name} — no per-sample predictions stored")
            continue

        y_true = np.asarray(data["y_true"])
        y_prob = np.asarray(data["y_prob"])
        if baseline_labels is None:
            # The first plotted model's labels define the baseline below.
            baseline_labels = y_true

        precision, recall, _ = precision_recall_curve(y_true, y_prob)
        avg_precision = average_precision_score(y_true, y_prob)

        ax.plot(
            recall, precision,
            color=_get_color(name),
            linewidth=2,
            label=f"{name} (AP = {avg_precision:.3f})",
        )

    # Baseline: proportion of positives (0.5 when no labels are available)
    if baseline_labels is not None and baseline_labels.size:
        baseline = float(np.mean(baseline_labels))
    else:
        baseline = 0.5
    ax.axhline(y=baseline, color="k", linestyle="--", linewidth=1, alpha=0.5,
               label=f"Baseline ({baseline:.2f})")

    ax.set_xlim([-0.02, 1.02])
    ax.set_ylim([-0.02, 1.05])
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title("Precision-Recall Curves — AURIS Model Comparison")
    ax.legend(loc="lower left", framealpha=0.9)

    _save_fig(fig, output_dir, "fig2_pr_curves")
197
+
198
+
199
+ # ═══════════════════════════════════════════════════════════════════════
200
+ # Figure 3: Confusion Matrices
201
+ # ═══════════════════════════════════════════════════════════════════════
202
+
203
def plot_confusion_matrices(
    results: dict[str, Any],
    output_dir: Path,
) -> None:
    """Plot a row-normalised confusion matrix heatmap for each model.

    Args:
        results: Mapping of model name to a result dict. Keys starting with
            ``"_"`` are metadata and are ignored. A model is plotted only if
            its dict holds both ``"y_true"`` and ``"y_pred"``.
        output_dir: Directory that receives
            ``fig3_confusion_matrices.png/.pdf``.
    """
    model_names = [
        name for name, data in results.items()
        if not name.startswith("_")
        and isinstance(data, dict)
        and "y_true" in data
        and "y_pred" in data
    ]
    if not model_names:
        # Nothing to draw; also avoids a zero column count below.
        print(" Skipping confusion matrices — no per-sample predictions stored")
        return

    n_models = len(model_names)
    cols = min(3, n_models)
    rows = (n_models + cols - 1) // cols  # ceiling division

    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid.
    fig, axes = plt.subplots(
        rows, cols, figsize=(5 * cols, 4.5 * rows), squeeze=False,
    )
    axes = axes.flatten()

    cmap = LinearSegmentedColormap.from_list("auris", ["#ffffff", AURIS_BLUE])

    for ax, name in zip(axes, model_names):
        data = results[name]
        y_true = np.asarray(data["y_true"])
        y_pred = np.asarray(data["y_pred"])

        # Fix the label set so the matrix is 2x2 even if a class is absent.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        row_totals = cm.sum(axis=1, keepdims=True)
        # Rows with no samples stay at 0 instead of becoming NaN.
        cm_norm = np.divide(
            cm, row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals > 0,
        )

        ax.imshow(cm_norm, interpolation="nearest", cmap=cmap, vmin=0, vmax=1)

        # Annotate cells with count and percentage
        for i in range(2):
            for j in range(2):
                # White text on dark cells, black on light ones
                color = "white" if cm_norm[i, j] > 0.6 else "black"
                ax.text(j, i, f"{cm[i, j]}\n({cm_norm[i, j]:.1%})",
                        ha="center", va="center", fontsize=12, color=color,
                        fontweight="bold")

        ax.set_xticks([0, 1])
        ax.set_yticks([0, 1])
        ax.set_xticklabels(["Human", "AI"])
        ax.set_yticklabels(["Human", "AI"])
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
        ax.set_title(name, fontsize=11)

    # Hide unused axes
    for ax in axes[n_models:]:
        ax.set_visible(False)

    fig.suptitle("Confusion Matrices — AURIS Model Comparison", fontsize=14, y=1.02)
    fig.tight_layout()
    _save_fig(fig, output_dir, "fig3_confusion_matrices")
254
+
255
+
256
+ # ═══════════════════════════════════════════════════════════════════════
257
+ # Figure 4: Model Comparison Bar Chart
258
+ # ═══════════════════════════════════════════════════════════════════════
259
+
260
def plot_model_comparison(
    results: dict[str, Any],
    output_dir: Path,
) -> None:
    """Draw grouped bars of five metrics per model.

    Every key of *results* not starting with ``"_"`` is treated as a model.
    A metric that is missing or ``None`` is drawn as 0. The figure is saved
    as ``fig4_model_comparison``.
    """
    # (result key, legend label, bar colour) for each metric series
    series = (
        ("accuracy", "Accuracy", "#4363d8"),
        ("precision", "Precision", "#3cb44b"),
        ("recall", "Recall", "#e6194b"),
        ("f1", "F1 Score", "#f58231"),
        ("roc_auc", "ROC-AUC", "#911eb4"),
    )
    model_names = [key for key in results if not key.startswith("_")]
    positions = np.arange(len(model_names))
    bar_width = 0.15

    fig, ax = plt.subplots(figsize=(max(10, len(model_names) * 2), 6))

    for offset, (metric, label, color) in enumerate(series):
        scores = []
        for model in model_names:
            score = results[model].get(metric, 0)
            scores.append(0 if score is None else score)
        rects = ax.bar(
            positions + offset * bar_width, scores, bar_width,
            label=label, color=color, alpha=0.85,
        )

        # Numeric label just above each bar
        for rect, score in zip(rects, scores):
            ax.text(
                rect.get_x() + rect.get_width() / 2,
                rect.get_height() + 0.01,
                f"{score:.2f}", ha="center", va="bottom", fontsize=7,
            )

    ax.set_xlabel("Model")
    ax.set_ylabel("Score")
    ax.set_title("Model Performance Comparison — AURIS")
    # Centre each tick under the middle (third) bar of its group
    ax.set_xticks(positions + bar_width * 2)
    ax.set_xticklabels(model_names, rotation=25, ha="right")
    ax.set_ylim([0, 1.12])
    ax.legend(loc="upper right", ncol=5, framealpha=0.9)
    ax.yaxis.set_major_formatter(mticker.PercentFormatter(1.0))

    fig.tight_layout()
    _save_fig(fig, output_dir, "fig4_model_comparison")
298
+
299
+
300
+ # ═══════════════════════════════════════════════════════════════════════
301
+ # Figure 5: Feature Importance
302
+ # ═══════════════════════════════════════════════════════════════════════
303
+
304
def plot_feature_importance(
    results: dict[str, Any],
    output_dir: Path,
    top_n: int = 20,
) -> None:
    """Horizontal bar chart of the top-N feature importances.

    Args:
        results: Results mapping; importances are read from its
            ``"_feature_importance"`` entry (feature name -> value) and the
            title uses ``"_best_model"`` when present.
        output_dir: Directory that receives
            ``fig5_feature_importance.png/.pdf``.
        top_n: Maximum number of features to show. Fewer are drawn when fewer
            are available; title and figure height follow the drawn count.
    """
    importance = results.get("_feature_importance")
    if not importance:
        print(" Skipping feature importance — no data available")
        return

    # Keep the top N, then reverse: barh draws the first entry at the bottom.
    ranked = sorted(importance.items(), key=lambda item: item[1], reverse=True)[:top_n]
    ranked.reverse()
    if not ranked:
        print(" Skipping feature importance — no data available")
        return

    names = [feature for feature, _ in ranked]
    values = [score for _, score in ranked]
    n_shown = len(names)

    fig, ax = plt.subplots(figsize=(8, max(6, n_shown * 0.35)))

    colors = plt.cm.Blues(np.linspace(0.3, 0.9, n_shown))
    bars = ax.barh(names, values, color=colors, edgecolor="white", linewidth=0.5)

    # Value labels
    for bar, val in zip(bars, values):
        ax.text(bar.get_width() + 0.002, bar.get_y() + bar.get_height() / 2,
                f"{val:.4f}", ha="left", va="center", fontsize=9)

    ax.set_xlabel("Relative Importance")
    ax.set_title(f"Top {n_shown} Feature Importances — AURIS ({results.get('_best_model', '')})")

    # Leave room for the value labels; with a non-positive maximum keep
    # matplotlib's automatic limits rather than a degenerate range.
    largest = max(values)
    if largest > 0:
        ax.set_xlim([0, largest * 1.15])

    fig.tight_layout()
    _save_fig(fig, output_dir, "fig5_feature_importance")
336
+
337
+
338
+ # ═══════════════════════════════════════════════════════════════════════
339
+ # Figure 6: Feature Correlation Heatmap
340
+ # ═══════════════════════════════════════════════════════════════════════
341
+
342
def plot_correlation_heatmap(
    features_csv: Path,
    output_dir: Path,
) -> None:
    """Plot the Pearson correlation matrix of all feature columns.

    Args:
        features_csv: CSV read through ``_load_features_with_names``.
        output_dir: Directory that receives
            ``fig6_correlation_heatmap.png/.pdf``.

    Features with zero variance have an undefined correlation (NaN); those
    cells are left as NaN rather than replaced with a made-up value.
    """
    X, _, feature_cols = _load_features_with_names(features_csv)

    # A correlation needs at least two samples and one feature.
    if X.ndim != 2 or X.shape[0] < 2 or X.shape[1] == 0:
        print(" Skipping correlation heatmap — not enough data")
        return

    # Constant columns divide by a zero standard deviation; silence the
    # expected floating-point warnings. atleast_2d covers the single-feature
    # case, where corrcoef returns a 0-d value.
    with np.errstate(divide="ignore", invalid="ignore"):
        corr = np.atleast_2d(np.corrcoef(X.T))

    fig, ax = plt.subplots(figsize=(max(12, len(feature_cols) * 0.4),
                                    max(10, len(feature_cols) * 0.35)))

    cmap = "RdBu_r" if HAS_SEABORN else "coolwarm"

    if HAS_SEABORN:
        sns.heatmap(
            corr, xticklabels=feature_cols, yticklabels=feature_cols,
            cmap=cmap, center=0, vmin=-1, vmax=1,
            square=True, linewidths=0.5, ax=ax,
            cbar_kws={"shrink": 0.8, "label": "Pearson Correlation"},
        )
    else:
        # Plain-matplotlib fallback when seaborn is not installed
        im = ax.imshow(corr, cmap=cmap, vmin=-1, vmax=1, aspect="auto")
        ax.set_xticks(range(len(feature_cols)))
        ax.set_yticks(range(len(feature_cols)))
        ax.set_xticklabels(feature_cols, rotation=90, fontsize=7)
        ax.set_yticklabels(feature_cols, fontsize=7)
        fig.colorbar(im, ax=ax, shrink=0.8, label="Pearson Correlation")

    ax.set_title("Feature Correlation Matrix — AURIS", fontsize=14, pad=20)

    fig.tight_layout()
    _save_fig(fig, output_dir, "fig6_correlation_heatmap")
375
+
376
+
377
+ # ═══════════════════════════════════════════════════════════════════════
378
+ # Figure 7: Feature Distribution (Violin / Box Plots)
379
+ # ═══════════════════════════════════════════════════════════════════════
380
+
381
def plot_feature_distributions(
    features_csv: Path,
    output_dir: Path,
    results: dict[str, Any] | None = None,
    top_n: int = 12,
) -> None:
    """Violin plots of per-feature distributions for AI vs Human samples.

    Args:
        features_csv: CSV read through ``_load_features_with_names``; rows
            with label 0 are plotted as "Human" and label 1 as "AI".
        output_dir: Directory that receives
            ``fig7_feature_distributions.png/.pdf``.
        results: Optional results mapping. When it has a
            ``"_feature_importance"`` entry, the *top_n* most important
            features that exist in the CSV are plotted; otherwise the first
            *top_n* CSV columns are used.
        top_n: Maximum number of features to plot.
    """
    X, y, feature_cols = _load_features_with_names(features_csv)

    # First occurrence wins, matching list.index() for duplicated names.
    col_index: dict[str, int] = {}
    for position, col_name in enumerate(feature_cols):
        col_index.setdefault(col_name, position)

    # Select top features by importance, or first N
    if results and "_feature_importance" in results:
        sorted_feats = sorted(
            results["_feature_importance"].items(),
            key=lambda item: item[1], reverse=True,
        )
        selected = [name for name, _ in sorted_feats[:top_n] if name in col_index]
    else:
        selected = feature_cols[:top_n]

    if not selected or X.ndim != 2 or X.shape[0] == 0:
        # A zero-row subplot grid is invalid, so bail out early.
        print(" Skipping feature distributions — no data available")
        return

    n_features = len(selected)
    cols = 3
    rows = (n_features + cols - 1) // cols  # ceiling division

    fig, axes = plt.subplots(
        rows, cols, figsize=(5 * cols, 3.5 * rows), squeeze=False,
    )
    axes = axes.flatten()

    if HAS_SEABORN:
        # Imported once here rather than on every loop iteration.
        import pandas as pd

    human_mask = y == 0
    ai_mask = y == 1

    for ax, feat_name in zip(axes, selected):
        column = X[:, col_index[feat_name]]
        human_vals = column[human_mask]
        ai_vals = column[ai_mask]

        if HAS_SEABORN:
            df = pd.DataFrame({
                "Feature": feat_name,
                "Value": np.concatenate([human_vals, ai_vals]),
                "Class": ["Human"] * len(human_vals) + ["AI"] * len(ai_vals),
            })
            sns.violinplot(
                data=df, x="Class", y="Value",
                palette={"Human": AURIS_BLUE, "AI": AURIS_RED},
                ax=ax, inner="quartile", linewidth=1,
            )
        else:
            # matplotlib's violinplot fails on empty datasets, so draw only
            # the classes that have samples.
            groups = [
                (pos, vals)
                for pos, vals in ((0, human_vals), (1, ai_vals))
                if len(vals)
            ]
            if groups:
                ax.violinplot(
                    [vals for _, vals in groups],
                    positions=[pos for pos, _ in groups],
                    showmeans=True, showmedians=True,
                )
            ax.set_xticks([0, 1])
            ax.set_xticklabels(["Human", "AI"])

        ax.set_title(feat_name, fontsize=10)
        ax.set_xlabel("")

    # Hide unused axes in the last grid row
    for ax in axes[n_features:]:
        ax.set_visible(False)

    fig.suptitle("Feature Distributions — AI vs Human", fontsize=14, y=1.01)
    fig.tight_layout()
    _save_fig(fig, output_dir, "fig7_feature_distributions")
446
+
447
+
448
+ # ═══════════════════════════════════════════════════════════════════════
449
+ # Figure 8: Training Summary Table (LaTeX)
450
+ # ═══════════════════════════════════════════════════════════════════════
451
+
452
# Characters with special meaning in LaTeX text mode and their escapes.
_LATEX_ESCAPES = str.maketrans({
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
})


def _metric_value(data: dict[str, Any], key: str) -> float:
    """Return ``data[key]``, treating a missing or ``None`` metric as 0."""
    value = data.get(key, 0)
    return 0 if value is None else value


def generate_latex_table(
    results: dict[str, Any],
    output_dir: Path,
) -> None:
    """Write the model comparison table as LaTeX and as Markdown.

    Two files are written to *output_dir*: ``table1_model_comparison.tex``
    (booktabs-style rules) and ``table1_model_comparison.md``.

    Args:
        results: Mapping of model name to a dict of metrics (``accuracy``,
            ``precision``, ``recall``, ``f1``, ``roc_auc``). Keys starting
            with ``"_"`` are metadata; ``"_best_model"`` names the row to set
            in bold. Missing or ``None`` metrics are written as ``0.0000``.
        output_dir: Existing directory that receives both files.
    """
    metric_keys = ("accuracy", "precision", "recall", "f1", "roc_auc")
    model_names = [
        name for name, data in results.items()
        if not name.startswith("_") and isinstance(data, dict)
    ]
    best_model = results.get("_best_model", "")

    lines = [
        r"\begin{table}[htbp]",
        r"\centering",
        r"\caption{AURIS Model Performance Comparison}",
        r"\label{tab:model-comparison}",
        r"\begin{tabular}{lccccc}",
        r"\toprule",
        r"Model & Accuracy & Precision & Recall & F1 & ROC-AUC \\",
        r"\midrule",
    ]

    for name in model_names:
        data = results[name]

        # Bold the best model
        prefix = r"\textbf{" if name == best_model else ""
        suffix = "}" if name == best_model else ""

        # Escape the name so characters such as "_" or "&" stay valid LaTeX.
        cells = [f"{prefix}{name.translate(_LATEX_ESCAPES)}{suffix}"]
        cells.extend(
            f"{prefix}{_metric_value(data, key):.4f}{suffix}"
            for key in metric_keys
        )
        lines.append(" " + " & ".join(cells) + " \\\\")

    lines.extend([
        r"\bottomrule",
        r"\end{tabular}",
        r"\end{table}",
    ])

    latex_content = "\n".join(lines)
    table_path = output_dir / "table1_model_comparison.tex"
    table_path.write_text(latex_content, encoding="utf-8")
    print(" Saved: table1_model_comparison.tex")

    # Also save as markdown for README
    md_lines = [
        "| Model | Accuracy | Precision | Recall | F1 | ROC-AUC |",
        "|-------|----------|-----------|--------|-----|---------|",
    ]
    for name in model_names:
        data = results[name]
        bold = "**" if name == best_model else ""
        # A literal pipe in a name would otherwise split the table cell.
        safe_name = name.replace("|", r"\|")
        cells = [f"{bold}{safe_name}{bold}"]
        cells.extend(f"{_metric_value(data, key):.4f}" for key in metric_keys)
        md_lines.append("| " + " | ".join(cells) + " |")

    md_path = output_dir / "table1_model_comparison.md"
    md_path.write_text("\n".join(md_lines), encoding="utf-8")
    print(" Saved: table1_model_comparison.md")
524
+
525
+
526
+ # ═══════════════════════════════════════════════════════════════════════
527
+ # Utilities
528
+ # ═══════════════════════════════════════════════════════════════════════
529
+
530
+ def _load_features_with_names(
531
+ features_csv: Path,
532
+ ) -> tuple[np.ndarray, np.ndarray, list[str]]:
533
+ """Load features CSV returning X, y, and column names."""
534
+ rows = []
535
+ labels = []
536
+
537
+ with open(features_csv, "r", encoding="utf-8") as f:
538
+ reader = csv.DictReader(f)
539
+ feature_cols = [
540
+ c for c in reader.fieldnames
541
+ if c not in ("file_path", "label_int")
542
+ ]
543
+ for row in reader:
544
+ feat_values = []
545
+ for col in feature_cols:
546
+ try:
547
+ feat_values.append(float(row[col]))
548
+ except (ValueError, KeyError):
549
+ feat_values.append(0.0)
550
+ rows.append(feat_values)
551
+ labels.append(int(row["label_int"]))
552
+
553
+ X = np.nan_to_num(np.array(rows, dtype=np.float32), nan=0.0)
554
+ y = np.array(labels, dtype=np.int32)
555
+
556
+ return X, y, feature_cols
557
+
558
+
559
+ # ═══════════════════════════════════════════════════════════════════════
560
+ # Main entry point
561
+ # ═══════════════════════════════════════════════════════════════════════
562
+
563
def generate_all_figures(
    results_path: str | Path,
    features_csv: str | Path,
    output_dir: str | Path = "figures",
) -> None:
    """Render every figure and table of the visualization pipeline.

    Loads the results JSON, creates *output_dir* if needed, and runs the
    eight steps in order. The ROC, PR and confusion-matrix steps are skipped
    when no model entry in the JSON contains a ``"y_true"`` array.
    """
    results_path = Path(results_path)
    features_csv = Path(features_csv)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nLoading results from {results_path}...")
    with open(results_path, "r") as handle:
        saved_results = json.load(handle)

    # ROC/PR/CM need per-sample arrays; find out whether any model entry
    # (non-underscore key with a dict value) stores them.
    has_arrays = False
    for key, entry in saved_results.items():
        if isinstance(entry, dict) and not key.startswith("_") and "y_true" in entry:
            has_arrays = True
            break

    print(f"\nGenerating figures in {output_dir}/...\n")

    if not has_arrays:
        print("[1-3/8] Skipping ROC/PR/CM — no per-sample predictions stored")
    else:
        prediction_steps = (
            ("[1/8] ROC Curves", plot_roc_curves),
            ("[2/8] Precision-Recall Curves", plot_pr_curves),
            ("[3/8] Confusion Matrices", plot_confusion_matrices),
        )
        for banner, plot in prediction_steps:
            print(banner)
            plot(saved_results, output_dir)

    # Steps that only need summary metrics
    print("[4/8] Model Comparison Bar Chart")
    plot_model_comparison(saved_results, output_dir)

    print("[5/8] Feature Importance")
    plot_feature_importance(saved_results, output_dir)

    # Steps that read the raw feature CSV
    print("[6/8] Feature Correlation Heatmap")
    plot_correlation_heatmap(features_csv, output_dir)

    print("[7/8] Feature Distributions (AI vs Human)")
    plot_feature_distributions(features_csv, output_dir, saved_results)

    print("[8/8] LaTeX / Markdown Tables")
    generate_latex_table(saved_results, output_dir)

    print(f"\nAll figures saved to {output_dir}/")
615
+
616
+
617
if __name__ == "__main__":
    # Command-line entry point: three optional path flags, then one call
    # into the pipeline.
    cli = argparse.ArgumentParser(description="AURIS Visualization Pipeline")
    for flag, default, help_text in (
        ("--results", "models/training_results.json", "Path to training_results.json"),
        ("--features", "data/training/features.csv", "Path to features.csv"),
        ("--output", "figures", "Output directory for figures"),
    ):
        cli.add_argument(flag, default=default, help=help_text)
    options = cli.parse_args()
    generate_all_figures(options.results, options.features, options.output)