"""evaluate.py — Metrics, Confusion Matrix, Error Analysis, ROC-AUC."""

import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    roc_curve,
    confusion_matrix,
    classification_report,
)

# Dark theme applied to every figure produced by this module.
plt.rcParams.update({
    "figure.facecolor": "#0f0f1a",
    "axes.facecolor": "#1a1a2e",
    "axes.edgecolor": "#444",
    "axes.labelcolor": "white",
    "text.color": "white",
    "xtick.color": "white",
    "ytick.color": "white",
    "grid.color": "#333",
    "font.family": "DejaVu Sans",
})

RESULTS_DIR = "results"
PLOTS_DIR = os.path.join(RESULTS_DIR, "confusion_matrices")
os.makedirs(PLOTS_DIR, exist_ok=True)

# Track all model results (test split only) for final comparison.
ALL_RESULTS = {}

# Binary label encoding assumed throughout: 0 = Negative, 1 = Positive.
_LABELS = [0, 1]
_TARGET_NAMES = ["Negative", "Positive"]


def _safe_name(model_name: str) -> str:
    """Return *model_name* with spaces and slashes replaced, for use in filenames."""
    return model_name.replace(" ", "_").replace("/", "_")


def _roc_auc_or_none(y_true, y_proba):
    """Return ROC-AUC, or None when no probabilities are given or AUC is undefined.

    AUC is undefined when ``y_true`` contains a single class; ``roc_auc_score``
    raises ValueError in that case, so it is checked up front instead.
    """
    if y_proba is None:
        return None
    if len(np.unique(y_true)) < 2:
        return None
    return roc_auc_score(y_true, y_proba)


def evaluate_model(y_true, y_pred, y_proba=None, model_name: str = "Model",
                   split: str = "test", save_plots: bool = False,
                   X_texts=None) -> dict:
    """
    Run the full evaluation suite for a binary classifier.

    - Accuracy, Precision, Recall, F1, ROC-AUC (printed to the console)
    - Confusion matrix and ROC curve (saved when ``save_plots`` is True)
    - Error analysis of misclassified samples (test split with ``X_texts``)

    Test-split metrics are also stored in ``ALL_RESULTS`` and persisted to
    ``results/metrics_summary.csv``.

    Args:
        y_true: True labels (0 = Negative, 1 = Positive).
        y_pred: Predicted labels.
        y_proba: Predicted positive-class probabilities (for ROC-AUC).
        model_name: Name label for this model.
        split: 'val' or 'test'.
        save_plots: Save figures to the results/ folder.
        X_texts: Optional raw texts for error analysis.

    Returns:
        dict with all metric values. ``roc_auc`` is None when ``y_proba`` is
        None or ``y_true`` contains only one class.
    """
    metrics = {
        "model": model_name,
        "split": split,
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, average="binary", zero_division=0),
        "recall": recall_score(y_true, y_pred, average="binary", zero_division=0),
        "f1": f1_score(y_true, y_pred, average="binary", zero_division=0),
        "roc_auc": _roc_auc_or_none(y_true, y_proba),
    }

    # Console report
    print(f"\n{'─' * 50}")
    print(f"  📊 {model_name} — {split.upper()} SET")
    print(f"{'─' * 50}")
    print(f"  Accuracy  : {metrics['accuracy']:.4f}")
    print(f"  Precision : {metrics['precision']:.4f}")
    print(f"  Recall    : {metrics['recall']:.4f}")
    print(f"  F1-Score  : {metrics['f1']:.4f}")
    # `is not None` so a legitimate AUC of 0.0 is still reported.
    if metrics["roc_auc"] is not None:
        print(f"  ROC-AUC   : {metrics['roc_auc']:.4f}")
    print(f"{'─' * 50}")
    # Explicit labels keep the report valid when a split contains one class only.
    print(classification_report(y_true, y_pred, labels=_LABELS,
                                target_names=_TARGET_NAMES, zero_division=0))

    if save_plots:
        _plot_confusion_matrix(y_true, y_pred, model_name, split)
        # No ROC curve without probabilities or when AUC is undefined.
        if metrics["roc_auc"] is not None:
            _plot_roc_curve(y_true, y_proba, model_name, split, metrics["roc_auc"])

    if split == "test":
        ALL_RESULTS[model_name] = metrics
        _save_metrics_csv()

    if X_texts is not None and split == "test":
        do_error_analysis(y_true, y_pred, y_proba, X_texts, model_name)

    return metrics


# ──────────────────────────────────────────────
# Confusion Matrix
# ──────────────────────────────────────────────

def _plot_confusion_matrix(y_true, y_pred, model_name: str, split: str):
    """Save a 2x2 confusion-matrix heatmap annotated with counts and row percentages."""
    # labels=_LABELS forces a 2x2 matrix even when a class is absent.
    cm = confusion_matrix(y_true, y_pred, labels=_LABELS)
    row_totals = cm.sum(axis=1, keepdims=True)
    # Share of each true class; empty rows yield 0 instead of NaN.
    cm_pct = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    fig, ax = plt.subplots(figsize=(7, 6))
    fig.patch.set_facecolor("#0f0f1a")
    ax.set_facecolor("#1a1a2e")

    sns.heatmap(cm, annot=False, fmt="d", cmap="Blues", ax=ax,
                linewidths=0.5, linecolor="#333", cbar_kws={"shrink": 0.8})

    # Annotate cells with count + percentage (rows = true, cols = predicted).
    cell_names = [["TN", "FP"], ["FN", "TP"]]
    for i in range(2):
        for j in range(2):
            ax.text(j + 0.5, i + 0.35, f"{cell_names[i][j]}\n{cm[i, j]:,}",
                    ha="center", va="center", fontsize=14,
                    color="white", fontweight="bold")
            ax.text(j + 0.5, i + 0.65, f"({cm_pct[i, j]:.1%})",
                    ha="center", va="center", fontsize=11, color="#aaa")

    ax.set_xticklabels(_TARGET_NAMES, fontsize=12)
    ax.set_yticklabels(_TARGET_NAMES, fontsize=12, rotation=0)
    ax.set_xlabel("Predicted Label", fontsize=13, labelpad=10)
    ax.set_ylabel("True Label", fontsize=13, labelpad=10)
    ax.set_title(f"Confusion Matrix — {model_name}\n({split} set)",
                 fontsize=14, fontweight="bold", pad=15)

    fig.tight_layout()
    os.makedirs(PLOTS_DIR, exist_ok=True)
    path = os.path.join(PLOTS_DIR, f"cm_{_safe_name(model_name)}_{split}.png")
    fig.savefig(path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
    print(f"  📊 Confusion matrix saved → {path}")


# ──────────────────────────────────────────────
# ROC Curve
# ──────────────────────────────────────────────

def _plot_roc_curve(y_true, y_proba, model_name: str, split: str, auc: float):
    """Save the ROC curve for positive-class probabilities *y_proba*."""
    fpr, tpr, _ = roc_curve(y_true, y_proba)

    fig, ax = plt.subplots(figsize=(7, 6))
    fig.patch.set_facecolor("#0f0f1a")
    ax.set_facecolor("#1a1a2e")

    curve_label = f"AUC = {auc:.4f}" if auc is not None else "ROC"
    ax.plot(fpr, tpr, color="#6c63ff", lw=2.5, label=curve_label)
    ax.plot([0, 1], [0, 1], "r--", lw=1.5, label="Random Classifier")
    ax.fill_between(fpr, tpr, alpha=0.15, color="#6c63ff")

    ax.set_xlabel("False Positive Rate", fontsize=12)
    ax.set_ylabel("True Positive Rate", fontsize=12)
    ax.set_title(f"ROC Curve — {model_name} ({split})", fontsize=13, fontweight="bold")
    ax.legend(fontsize=11, loc="lower right")
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    os.makedirs(PLOTS_DIR, exist_ok=True)
    path = os.path.join(PLOTS_DIR, f"roc_{_safe_name(model_name)}_{split}.png")
    fig.savefig(path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
    print(f"  📈 ROC curve saved → {path}")


# ──────────────────────────────────────────────
# Error Analysis
# ──────────────────────────────────────────────

def do_error_analysis(y_true, y_pred, y_proba, X_texts, model_name: str,
                      n_samples: int = 30):
    """
    Identify and save misclassified samples with confidence scores.

    Outputs:
      - results/error_analysis_{model_name}.csv
      - Console summary of error patterns

    Args:
        y_true: True binary labels (0 = Negative, 1 = Positive).
        y_pred: Predicted binary labels.
        y_proba: Predicted positive-class probabilities, or None. When None the
            probability columns are NaN and no mistake is flagged as
            high-confidence.
        X_texts: Raw texts aligned with ``y_true``.
        model_name: Name used in the console report and the output filename.
        n_samples: Currently unused; kept for backward compatibility.

    Returns:
        DataFrame with one row per misclassified sample. ``confidence`` is the
        positive-class probability; ``pred_confidence`` is the probability the
        model assigned to the class it (wrongly) predicted.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    X_texts = np.asarray(X_texts)
    if y_proba is None:
        # NaN rather than 0: zeros would look like maximally confident predictions.
        y_proba = np.full(len(y_true), np.nan)
    else:
        y_proba = np.asarray(y_proba, dtype=float)

    n_total = len(y_true)
    wrong = y_true != y_pred
    n_errors = int(wrong.sum())
    error_pct = n_errors / n_total * 100 if n_total else 0.0

    print(f"\n🔍 Error Analysis — {model_name}")
    print(f"   Total errors: {n_errors}/{n_total} ({error_pct:.1f}%)")

    wrong_texts = X_texts[wrong]
    wrong_true = y_true[wrong]
    wrong_pred = y_pred[wrong]
    wrong_proba = y_proba[wrong]

    # Error types
    fp_mask = (wrong_true == 0) & (wrong_pred == 1)
    fn_mask = (wrong_true == 1) & (wrong_pred == 0)
    print(f"   False Positives (neg→pos): {fp_mask.sum()}")
    print(f"   False Negatives (pos→neg): {fn_mask.sum()}")

    # High-confidence mistakes (model very wrong). NaN compares False, so
    # nothing is flagged when no probabilities were provided.
    high_conf = np.abs(wrong_proba - 0.5) > 0.3
    print(f"   High-confidence mistakes: {high_conf.sum()}")

    # Confidence in the predicted class. Ranking by the raw positive-class
    # probability would put confident false negatives (p ≈ 0) last.
    pred_conf = np.where(wrong_pred == 1, wrong_proba, 1.0 - wrong_proba)

    error_df = pd.DataFrame({
        "text": wrong_texts,
        "true_label": ["Positive" if v == 1 else "Negative" for v in wrong_true],
        "pred_label": ["Positive" if v == 1 else "Negative" for v in wrong_pred],
        # Binary labels: every mistake is either a false positive or a false negative.
        "error_type": np.where(fp_mask, "FP", "FN"),
        "confidence": wrong_proba,
        "high_confidence_mistake": high_conf,
    })
    # Trim text for readability (astype(str) also handles empty/non-string input).
    error_df["text_preview"] = error_df["text"].astype(str).str[:200]
    error_df["pred_confidence"] = pred_conf

    os.makedirs(RESULTS_DIR, exist_ok=True)
    out_path = os.path.join(RESULTS_DIR, f"error_analysis_{_safe_name(model_name)}.csv")
    error_df.to_csv(out_path, index=False)
    print(f"   💾 Error analysis saved → {out_path}")

    # Print most confident mistakes
    n_top = min(5, n_errors)
    print(f"\n   Top {n_top} most confident mistakes:")
    top = error_df.sort_values("pred_confidence", ascending=False, kind="stable").head(n_top)
    for row in top.itertuples(index=False):
        print(f"   [{row.error_type}] conf={row.pred_confidence:.3f} | "
              f"'{row.text_preview[:80]}...'")

    return error_df


# ──────────────────────────────────────────────
# Comparison Chart
# ──────────────────────────────────────────────

def plot_model_comparison():
    """Plot a grouped bar chart comparing every model recorded in ALL_RESULTS."""
    if not ALL_RESULTS:
        print("⚠️  No results to compare yet.")
        return

    metrics_to_plot = ["accuracy", "precision", "recall", "f1", "roc_auc"]
    # None (e.g. missing ROC-AUC) becomes NaN after the float cast.
    df = pd.DataFrame(ALL_RESULTS).T[metrics_to_plot].astype(float)
    n_models = len(df)

    fig, ax = plt.subplots(figsize=(11, 6))
    fig.patch.set_facecolor("#0f0f1a")
    ax.set_facecolor("#1a1a2e")

    x = np.arange(len(metrics_to_plot))
    # Narrow the bars as models are added so each group stays inside its slot.
    width = min(0.22, 0.8 / n_models)
    colors = ["#6c63ff", "#ff6584", "#43aa8b"]

    for i, (model, row) in enumerate(df.iterrows()):
        values = row.to_numpy(dtype=float)
        bars = ax.bar(x + i * width, values, width, label=model,
                      color=colors[i % len(colors)], alpha=0.9,
                      edgecolor="white", linewidth=0.5)
        for bar, value in zip(bars, values):
            if np.isnan(value):
                continue  # no bar to label (metric unavailable for this model)
            ax.text(bar.get_x() + bar.get_width() / 2, value + 0.005,
                    f"{value:.3f}", ha="center", va="bottom",
                    fontsize=8, color="white")

    # Centre each tick under its group of bars, for any number of models.
    ax.set_xticks(x + width * (n_models - 1) / 2)
    ax.set_xticklabels([m.replace("_", " ").upper() for m in metrics_to_plot], fontsize=11)

    # Keep the 0.80 zoom by default, but never hide bars that score lower.
    all_values = df.to_numpy(dtype=float)
    finite = all_values[np.isfinite(all_values)]
    lower = min(0.80, float(finite.min()) - 0.02) if finite.size else 0.80
    ax.set_ylim(max(0.0, lower), 1.01)

    ax.set_ylabel("Score", fontsize=12)
    ax.set_title("Model Comparison — Sentiment Analysis (IMDB)",
                 fontsize=14, fontweight="bold", pad=15)
    ax.legend(fontsize=10, loc="lower right")
    ax.grid(True, axis="y", alpha=0.3)

    fig.tight_layout()
    os.makedirs(RESULTS_DIR, exist_ok=True)
    path = os.path.join(RESULTS_DIR, "model_comparison.png")
    fig.savefig(path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
    print(f"\n📊 Comparison chart saved → {path}")


def _save_metrics_csv():
    """Persist all recorded model metrics (one row per model) to CSV."""
    if not ALL_RESULTS:
        return
    os.makedirs(RESULTS_DIR, exist_ok=True)
    path = os.path.join(RESULTS_DIR, "metrics_summary.csv")
    pd.DataFrame(ALL_RESULTS).T.to_csv(path)
    print(f"   💾 Metrics summary → {path}")