Spaces:
Sleeping
Sleeping
"""
evaluate.py — Metrics, Confusion Matrix, Error Analysis, ROC-AUC
"""
| import os | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import matplotlib.gridspec as gridspec | |
| import seaborn as sns | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| precision_score, | |
| recall_score, | |
| f1_score, | |
| roc_auc_score, | |
| roc_curve, | |
| confusion_matrix, | |
| classification_report, | |
| ) | |
# Style: dark theme applied to every figure this module creates.
plt.rcParams.update(
    {
        # Backgrounds and frame
        "figure.facecolor": "#0f0f1a",
        "axes.facecolor": "#1a1a2e",
        "axes.edgecolor": "#444",
        # Text and tick colours
        "axes.labelcolor": "white",
        "text.color": "white",
        "xtick.color": "white",
        "ytick.color": "white",
        # Grid and font
        "grid.color": "#333",
        "font.family": "DejaVu Sans",
    }
)

# Output locations. The plots sub-folder (and with it RESULTS_DIR) is created
# as soon as this module is imported.
RESULTS_DIR = "results"
PLOTS_DIR = os.path.join(RESULTS_DIR, "confusion_matrices")
os.makedirs(PLOTS_DIR, exist_ok=True)

# Test-split metrics keyed by model name; feeds the final comparison chart
# and metrics_summary.csv.
ALL_RESULTS = dict()
def evaluate_model(y_true, y_pred, y_proba=None, model_name: str = "Model",
                   split: str = "test", save_plots: bool = False,
                   X_texts=None) -> dict:
    """
    Compute, print and optionally plot binary-classification metrics.

    Full evaluation suite:
      - Accuracy, Precision, Recall, F1 (positive class = 1), ROC-AUC
      - Confusion matrix (plotted when ``save_plots`` is set)
      - ROC curve (plotted when ``save_plots`` is set and ROC-AUC is defined)
      - Error analysis (misclassified samples; test split with X_texts only)

    Test-split metrics are also recorded in ALL_RESULTS and written to
    results/metrics_summary.csv.

    Args:
        y_true:     True labels (0 = negative, 1 = positive).
        y_pred:     Predicted labels.
        y_proba:    Predicted positive-class probabilities (for ROC-AUC).
        model_name: Name label for this model.
        split:      'val' or 'test'.
        save_plots: Save figures to the results/ folder.
        X_texts:    Optional raw texts for error analysis.

    Returns:
        dict with keys model, split, accuracy, precision, recall, f1 and
        roc_auc. roc_auc is None when no probabilities are given, or when
        y_true holds a single class (ROC-AUC is undefined there).
    """
    roc_auc = None
    if y_proba is not None:
        if len(np.unique(y_true)) > 1:
            roc_auc = roc_auc_score(y_true, y_proba)
        else:
            # roc_auc_score raises ValueError when only one class is present.
            print(f" ⚠️ ROC-AUC undefined for {model_name} ({split}): "
                  f"y_true contains a single class")

    metrics = {
        "model": model_name,
        "split": split,
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, average="binary", zero_division=0),
        "recall": recall_score(y_true, y_pred, average="binary", zero_division=0),
        "f1": f1_score(y_true, y_pred, average="binary", zero_division=0),
        "roc_auc": roc_auc,
    }

    # Console report
    print(f"\n{'─' * 50}")
    print(f" 📊 {model_name} — {split.upper()} SET")
    print(f"{'─' * 50}")
    print(f" Accuracy : {metrics['accuracy']:.4f}")
    print(f" Precision : {metrics['precision']:.4f}")
    print(f" Recall : {metrics['recall']:.4f}")
    print(f" F1-Score : {metrics['f1']:.4f}")
    # Explicit None check: an AUC of 0.0 is a valid (if terrible) score.
    if metrics["roc_auc"] is not None:
        print(f" ROC-AUC : {metrics['roc_auc']:.4f}")
    print(f"{'─' * 50}")
    # labels=[0, 1] keeps the two target_names valid even when only one
    # class appears in y_true/y_pred.
    print(classification_report(y_true, y_pred, labels=[0, 1],
                                target_names=["Negative", "Positive"]))

    if save_plots:
        _plot_confusion_matrix(y_true, y_pred, model_name, split)
        if metrics["roc_auc"] is not None:
            _plot_roc_curve(y_true, y_proba, model_name, split, metrics["roc_auc"])

    if split == "test":
        # Store a copy so later mutation of the returned dict by the caller
        # cannot silently alter the comparison registry.
        ALL_RESULTS[model_name] = dict(metrics)
        _save_metrics_csv()

    if X_texts is not None and split == "test":
        do_error_analysis(y_true, y_pred, y_proba, X_texts, model_name)

    return metrics
# ─────────────────────────────────────────────
# Confusion Matrix
# ─────────────────────────────────────────────
def _plot_confusion_matrix(y_true, y_pred, model_name: str, split: str):
    """
    Save an annotated 2x2 confusion-matrix heatmap to PLOTS_DIR.

    Rows are true labels and columns are predictions, both ordered
    [0 = Negative, 1 = Positive]. Each cell shows its TN/FP/FN/TP tag, the
    raw count and the cell's share of its true-label row. The file is named
    cm_{model_name}_{split}.png, with spaces and "/" in the name replaced
    by "_".
    """
    # Pin the label order so the matrix is always 2x2, even when a class is
    # absent from both y_true and y_pred (the annotation loop indexes [1][1]).
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    row_totals = cm.sum(axis=1, keepdims=True)
    # Row-normalise; a row with no true samples shows 0% instead of NaN.
    cm_pct = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float),
                       where=row_totals > 0)

    fig, ax = plt.subplots(figsize=(7, 6))
    fig.patch.set_facecolor("#0f0f1a")
    ax.set_facecolor("#1a1a2e")
    sns.heatmap(cm, annot=False, fmt="d", cmap="Blues", ax=ax,
                linewidths=0.5, linecolor="#333",
                cbar_kws={"shrink": 0.8})

    # Annotate cells with tag + count (upper half) and row percentage (lower).
    labels = [["TN", "FP"], ["FN", "TP"]]
    for i in range(2):
        for j in range(2):
            ax.text(j + 0.5, i + 0.35,
                    f"{labels[i][j]}\n{cm[i][j]:,}",
                    ha="center", va="center", fontsize=14,
                    color="white", fontweight="bold")
            ax.text(j + 0.5, i + 0.65,
                    f"({cm_pct[i][j]:.1%})",
                    ha="center", va="center", fontsize=11, color="#aaa")

    ax.set_xticklabels(["Negative", "Positive"], fontsize=12)
    ax.set_yticklabels(["Negative", "Positive"], fontsize=12, rotation=0)
    ax.set_xlabel("Predicted Label", fontsize=13, labelpad=10)
    ax.set_ylabel("True Label", fontsize=13, labelpad=10)
    ax.set_title(f"Confusion Matrix — {model_name}\n({split} set)",
                 fontsize=14, fontweight="bold", pad=15)
    fig.tight_layout()

    # The folder is created at import, but may have been removed since.
    os.makedirs(PLOTS_DIR, exist_ok=True)
    safe_name = model_name.replace(" ", "_").replace("/", "_")
    path = os.path.join(PLOTS_DIR, f"cm_{safe_name}_{split}.png")
    fig.savefig(path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    # Close this specific figure, not whichever one happens to be current.
    plt.close(fig)
    print(f" 📊 Confusion matrix saved → {path}")
# ─────────────────────────────────────────────
# ROC Curve
# ─────────────────────────────────────────────
def _plot_roc_curve(y_true, y_proba, model_name: str, split: str, auc: "float | None"):
    """
    Save a ROC curve (with the random-classifier diagonal) to PLOTS_DIR.

    Args:
        y_true:     True binary labels.
        y_proba:    Positive-class probabilities/scores.
        model_name: Used in the title and in the output file name.
        split:      Split name used in the title and in the output file name.
        auc:        Pre-computed ROC-AUC shown in the legend; None omits it.
    """
    fpr, tpr, _ = roc_curve(y_true, y_proba)

    fig, ax = plt.subplots(figsize=(7, 6))
    fig.patch.set_facecolor("#0f0f1a")
    ax.set_facecolor("#1a1a2e")

    # Guard the format spec: f"{None:.4f}" would raise TypeError.
    curve_label = f"AUC = {auc:.4f}" if auc is not None else "ROC"
    ax.plot(fpr, tpr, color="#6c63ff", lw=2.5, label=curve_label)
    ax.plot([0, 1], [0, 1], "r--", lw=1.5, label="Random Classifier")
    ax.fill_between(fpr, tpr, alpha=0.15, color="#6c63ff")

    ax.set_xlabel("False Positive Rate", fontsize=12)
    ax.set_ylabel("True Positive Rate", fontsize=12)
    ax.set_title(f"ROC Curve — {model_name} ({split})", fontsize=13, fontweight="bold")
    ax.legend(fontsize=11, loc="lower right")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    os.makedirs(PLOTS_DIR, exist_ok=True)
    safe_name = model_name.replace(" ", "_").replace("/", "_")
    path = os.path.join(PLOTS_DIR, f"roc_{safe_name}_{split}.png")
    fig.savefig(path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
    print(f" 📈 ROC curve saved → {path}")
# ─────────────────────────────────────────────
# Error Analysis
# ─────────────────────────────────────────────
def do_error_analysis(y_true, y_pred, y_proba, X_texts,
                      model_name: str, n_samples: int = 30):
    """
    Identify and save misclassified samples with confidence scores.

    Args:
        y_true:     True labels (0 = negative, 1 = positive).
        y_pred:     Predicted labels.
        y_proba:    Predicted positive-class probabilities, or None. When
                    None, confidence columns are NaN and no error is flagged
                    as a high-confidence mistake.
        X_texts:    Raw texts aligned with the labels.
        model_name: Used in console output and in the output file name.
        n_samples:  Currently unused; kept for API compatibility.

    Outputs:
        - results/error_analysis_{model_name}.csv (spaces and "/" in the
          name replaced by "_")
        - Console summary of error patterns

    Returns:
        DataFrame with one row per misclassified sample and columns text,
        true_label, pred_label, error_type ("FP"/"FN"), confidence
        (P(positive)), high_confidence_mistake, text_preview (first 200
        chars) and pred_confidence (probability of the predicted class).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Use NaN, not 0.0, when probabilities are missing: a 0.0 placeholder is
    # 0.5 away from the decision boundary and would flag every error as a
    # high-confidence mistake.
    if y_proba is not None:
        y_proba = np.asarray(y_proba, dtype=float)
    else:
        y_proba = np.full(len(y_true), np.nan)
    X_texts = np.asarray(X_texts)

    misclassified_mask = y_true != y_pred
    n_errors = int(misclassified_mask.sum())
    n_total = len(y_true)
    error_rate = n_errors / n_total * 100 if n_total else 0.0

    print(f"\n🔍 Error Analysis — {model_name}")
    print(f" Total errors: {n_errors}/{n_total} "
          f"({error_rate:.1f}%)")

    wrong_texts = X_texts[misclassified_mask]
    wrong_true = y_true[misclassified_mask]
    wrong_pred = y_pred[misclassified_mask]
    wrong_conf = y_proba[misclassified_mask]

    # Error types
    fp_mask = (wrong_true == 0) & (wrong_pred == 1)
    fn_mask = (wrong_true == 1) & (wrong_pred == 0)
    print(f" False Positives (neg→pos): {fp_mask.sum()}")
    print(f" False Negatives (pos→neg): {fn_mask.sum()}")

    # High-confidence mistakes: P(positive) more than 0.3 from the 0.5
    # boundary (> 0.8 or < 0.2). NaN (no probabilities) compares False.
    high_conf_errors = np.abs(wrong_conf - 0.5) > 0.3
    print(f" High-confidence mistakes: {high_conf_errors.sum()}")

    # Probability the model gave to the label it (wrongly) predicted, so a
    # confident false negative (P(positive) near 0) ranks as high as a
    # confident false positive.
    pred_conf = np.where(wrong_pred == 1, wrong_conf, 1.0 - wrong_conf)

    error_df = pd.DataFrame({
        "text": wrong_texts,
        "true_label": np.where(wrong_true == 1, "Positive", "Negative"),
        "pred_label": np.where(wrong_pred == 1, "Positive", "Negative"),
        # A misclassified sample that is not a false positive is counted
        # as a false negative.
        "error_type": np.where(fp_mask, "FP", "FN"),
        "confidence": wrong_conf,
        "high_confidence_mistake": high_conf_errors,
    })
    # Trim text for readability (astype(str) tolerates non-string entries).
    error_df["text_preview"] = error_df["text"].astype(str).str[:200]
    # Appended last so the original column positions are unchanged.
    error_df["pred_confidence"] = pred_conf

    os.makedirs(RESULTS_DIR, exist_ok=True)
    safe_name = model_name.replace(" ", "_").replace("/", "_")
    out_path = os.path.join(RESULTS_DIR, f"error_analysis_{safe_name}.csv")
    error_df.to_csv(out_path, index=False)
    print(f" 💾 Error analysis saved → {out_path}")

    # Print most confident mistakes (by confidence in the predicted class).
    n_top = min(5, n_errors)
    print(f"\n Top {n_top} most confident mistakes:")
    top = error_df.sort_values("pred_confidence", ascending=False).head(n_top)
    for _, row in top.iterrows():
        print(f" [{row['error_type']}] conf={row['pred_confidence']:.3f} | "
              f"'{row['text_preview'][:80]}...'")

    return error_df
# ─────────────────────────────────────────────
# Comparison Chart
# ─────────────────────────────────────────────
def plot_model_comparison():
    """
    Plot a side-by-side comparison bar chart of all evaluated models.

    Draws one bar group per metric and one bar per model recorded in
    ALL_RESULTS, and saves the chart to results/model_comparison.png. If no
    model has been recorded yet, it prints a warning and returns.
    """
    if not ALL_RESULTS:
        print("⚠️ No results to compare yet.")
        return

    metrics_to_plot = ["accuracy", "precision", "recall", "f1", "roc_auc"]
    # Rows = models, columns = metrics; a None roc_auc becomes NaN here.
    df = pd.DataFrame(ALL_RESULTS).T[metrics_to_plot].astype(float)
    n_models = len(df)

    fig, ax = plt.subplots(figsize=(11, 6))
    fig.patch.set_facecolor("#0f0f1a")
    ax.set_facecolor("#1a1a2e")

    x = np.arange(len(metrics_to_plot))
    # Keep each bar group within 0.8 x-units so groups never overlap; for up
    # to three models this is the original 0.22 width.
    width = min(0.22, 0.8 / n_models)
    colors = ["#6c63ff", "#ff6584", "#43aa8b"]

    for i, (model, row) in enumerate(df.iterrows()):
        bars = ax.bar(x + i * width, row.values, width,
                      label=model, color=colors[i % len(colors)],
                      alpha=0.9, edgecolor="white", linewidth=0.5)
        for bar in bars:
            h = bar.get_height()
            if np.isnan(h):
                continue  # metric unavailable (e.g. no ROC-AUC): no label
            ax.text(bar.get_x() + bar.get_width() / 2, h + 0.005,
                    f"{h:.3f}", ha="center", va="bottom",
                    fontsize=8, color="white")

    # Centre each tick under its bar group, whatever the number of models.
    ax.set_xticks(x + width * (n_models - 1) / 2)
    ax.set_xticklabels([m.replace("_", " ").upper() for m in metrics_to_plot],
                       fontsize=11)

    # Zoom in on the usual 0.80-1.0 range, but lower the floor when any score
    # is below it so no bar is clipped out of view.
    values = df.to_numpy(dtype=float)
    finite = values[np.isfinite(values)]
    y_floor = 0.80
    if finite.size and finite.min() < y_floor:
        y_floor = max(0.0, float(finite.min()) - 0.05)
    ax.set_ylim(y_floor, 1.01)

    ax.set_ylabel("Score", fontsize=12)
    ax.set_title("Model Comparison — Sentiment Analysis (IMDB)", fontsize=14,
                 fontweight="bold", pad=15)
    ax.legend(fontsize=10, loc="lower right")
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()

    os.makedirs(RESULTS_DIR, exist_ok=True)
    path = os.path.join(RESULTS_DIR, "model_comparison.png")
    fig.savefig(path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
    print(f"\n📊 Comparison chart saved → {path}")
def _save_metrics_csv():
    """
    Persist all recorded model metrics to results/metrics_summary.csv.

    Writes one row per model (index = model name) and overwrites the file on
    every call. Does nothing when ALL_RESULTS is empty.
    """
    if not ALL_RESULTS:
        return
    df = pd.DataFrame(ALL_RESULTS).T
    # RESULTS_DIR is created at import, but may have been removed since.
    os.makedirs(RESULTS_DIR, exist_ok=True)
    path = os.path.join(RESULTS_DIR, "metrics_summary.csv")
    df.to_csv(path)
    print(f" 💾 Metrics summary → {path}")