# Sentiment-Analysis/src/evaluate.py
# Uploaded by najahaja ("Upload 26 files"), commit c247f12 (verified).
"""
evaluate.py — Metrics, Confusion Matrix, Error Analysis, ROC-AUC
"""
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn.metrics import (
accuracy_score,
precision_score,
recall_score,
f1_score,
roc_auc_score,
roc_curve,
confusion_matrix,
classification_report,
)
# Style: dark theme applied globally (via plt.rcParams) to the figures this
# module draws.
plt.rcParams.update([
    ("figure.facecolor", "#0f0f1a"),
    ("axes.facecolor", "#1a1a2e"),
    ("axes.edgecolor", "#444"),
    ("axes.labelcolor", "white"),
    ("text.color", "white"),
    ("xtick.color", "white"),
    ("ytick.color", "white"),
    ("grid.color", "#333"),
    ("font.family", "DejaVu Sans"),
])

# Output locations. PLOTS_DIR lives inside RESULTS_DIR, so creating it here at
# import time creates both directories.
RESULTS_DIR = "results"
PLOTS_DIR = os.path.join(RESULTS_DIR, "confusion_matrices")
os.makedirs(PLOTS_DIR, exist_ok=True)

# Test-split metrics of every evaluated model, keyed by model name; read by the
# comparison chart and the metrics CSV writer.
ALL_RESULTS = dict()
def evaluate_model(y_true, y_pred, y_proba=None, model_name: str = "Model",
                   split: str = "test", save_plots: bool = False,
                   X_texts=None) -> dict:
    """
    Full evaluation suite for a binary classifier (0 = Negative, 1 = Positive):
      - Accuracy, Precision, Recall, F1 (positive class = 1), ROC-AUC
      - Confusion matrix (plotted)
      - ROC curve (plotted)
      - Error analysis (misclassified samples)
    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        y_proba: Predicted positive-class probabilities (for ROC-AUC).
        model_name: Name label for this model.
        split: 'val' or 'test'.
        save_plots: Save the confusion-matrix / ROC figures to PLOTS_DIR.
        X_texts: Optional raw texts for error analysis (test split only).
    Returns:
        dict with keys "model", "split", "accuracy", "precision", "recall",
        "f1" and "roc_auc". "roc_auc" is None when y_proba is not given or
        when y_true contains a single class (ROC-AUC is undefined then).
    Side effects:
        For split == "test" the metrics are stored in ALL_RESULTS under
        model_name and the metrics summary CSV is rewritten.
    """
    # roc_auc_score raises when y_true holds only one class, so only compute it
    # when it is defined.
    auc_defined = y_proba is not None and len(np.unique(y_true)) > 1
    metrics = {
        "model": model_name,
        "split": split,
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, average="binary", zero_division=0),
        "recall": recall_score(y_true, y_pred, average="binary", zero_division=0),
        "f1": f1_score(y_true, y_pred, average="binary", zero_division=0),
        "roc_auc": roc_auc_score(y_true, y_proba) if auc_defined else None,
    }

    # Console report
    print(f"\n{'─'*50}")
    print(f" 📊 {model_name} — {split.upper()} SET")
    print(f"{'─'*50}")
    print(f" Accuracy : {metrics['accuracy']:.4f}")
    print(f" Precision : {metrics['precision']:.4f}")
    print(f" Recall : {metrics['recall']:.4f}")
    print(f" F1-Score : {metrics['f1']:.4f}")
    # `is not None` so that a legitimate AUC of 0.0 is still reported.
    if metrics["roc_auc"] is not None:
        print(f" ROC-AUC : {metrics['roc_auc']:.4f}")
    elif y_proba is not None:
        print(" ROC-AUC : undefined (only one class in y_true)")
    print(f"{'─'*50}")
    # labels=[0, 1] keeps the report at two rows (matching target_names) even
    # when a class is absent from both y_true and y_pred; zero_division=0
    # matches the metric calls above.
    print(classification_report(y_true, y_pred, labels=[0, 1],
                                target_names=["Negative", "Positive"],
                                zero_division=0))

    if save_plots:
        _plot_confusion_matrix(y_true, y_pred, model_name, split)
        # No ROC plot when ROC-AUC is undefined (no probabilities / one class).
        if metrics["roc_auc"] is not None:
            _plot_roc_curve(y_true, y_proba, model_name, split, metrics["roc_auc"])

    if split == "test":
        ALL_RESULTS[model_name] = metrics
        _save_metrics_csv()

    if X_texts is not None and split == "test":
        do_error_analysis(y_true, y_pred, y_proba, X_texts, model_name)

    return metrics
# ──────────────────────────────────────────────
# Confusion Matrix
# ──────────────────────────────────────────────
def _plot_confusion_matrix(y_true, y_pred, model_name: str, split: str):
    """
    Plot an annotated binary confusion matrix and save it as a PNG.

    Each cell shows its TN/FP/FN/TP tag with the raw count, plus the count as a
    percentage of its true-label row. The file is written to
    PLOTS_DIR/cm_<model_name>_<split>.png, with spaces and "/" in model_name
    replaced by "_".

    Args:
        y_true: True labels (0 = Negative, 1 = Positive).
        y_pred: Predicted labels.
        model_name: Shown in the title and used in the file name.
        split: Split name ('val' or 'test'), shown in the title and file name.
    """
    # labels=[0, 1] fixes the matrix at 2x2 in Negative/Positive order even when
    # a class is absent; otherwise the cell loop below would index past a 1x1
    # matrix.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    row_totals = cm.sum(axis=1, keepdims=True)
    # Row-normalise; a class with no true samples shows 0.0% instead of nan%.
    cm_pct = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    fig, ax = plt.subplots(figsize=(7, 6))
    fig.patch.set_facecolor("#0f0f1a")
    ax.set_facecolor("#1a1a2e")
    # Colour-only heatmap; the cell text is drawn manually below.
    sns.heatmap(cm, annot=False, cmap="Blues", ax=ax,
                linewidths=0.5, linecolor="#333",
                cbar_kws={"shrink": 0.8})

    # Annotate cells with tag + count and, below it, the row percentage
    cell_tags = [["TN", "FP"], ["FN", "TP"]]
    for i in range(2):
        for j in range(2):
            ax.text(j + 0.5, i + 0.35,
                    f"{cell_tags[i][j]}\n{cm[i, j]:,}",
                    ha="center", va="center", fontsize=14,
                    color="white", fontweight="bold")
            ax.text(j + 0.5, i + 0.65,
                    f"({cm_pct[i, j]:.1%})",
                    ha="center", va="center", fontsize=11, color="#aaa")

    ax.set_xticklabels(["Negative", "Positive"], fontsize=12)
    ax.set_yticklabels(["Negative", "Positive"], fontsize=12, rotation=0)
    ax.set_xlabel("Predicted Label", fontsize=13, labelpad=10)
    ax.set_ylabel("True Label", fontsize=13, labelpad=10)
    ax.set_title(f"Confusion Matrix — {model_name}\n({split} set)",
                 fontsize=14, fontweight="bold", pad=15)
    plt.tight_layout()

    safe_name = model_name.replace(" ", "_").replace("/", "_")
    path = os.path.join(PLOTS_DIR, f"cm_{safe_name}_{split}.png")
    plt.savefig(path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
    print(f" 📊 Confusion matrix saved → {path}")
# ──────────────────────────────────────────────
# ROC Curve
# ──────────────────────────────────────────────
def _plot_roc_curve(y_true, y_proba, model_name: str, split: str,
                    auc: "float | None" = None):
    """
    Plot the ROC curve with its AUC and save it as a PNG.

    The curve is drawn against the random-classifier diagonal, with the area
    under it shaded. The file is written to
    PLOTS_DIR/roc_<model_name>_<split>.png, with spaces and "/" in model_name
    replaced by "_".

    Args:
        y_true: True labels.
        y_proba: Predicted positive-class probabilities.
        model_name: Shown in the title and used in the file name.
        split: Split name, shown in the title and file name.
        auc: Pre-computed ROC-AUC for the legend; computed from y_true and
            y_proba when None.
    """
    if auc is None:
        auc = roc_auc_score(y_true, y_proba)
    fpr, tpr, _ = roc_curve(y_true, y_proba)

    fig, ax = plt.subplots(figsize=(7, 6))
    fig.patch.set_facecolor("#0f0f1a")
    ax.set_facecolor("#1a1a2e")
    ax.plot(fpr, tpr, color="#6c63ff", lw=2.5,
            label=f"AUC = {auc:.4f}")
    ax.plot([0, 1], [0, 1], "r--", lw=1.5, label="Random Classifier")
    ax.fill_between(fpr, tpr, alpha=0.15, color="#6c63ff")
    ax.set_xlabel("False Positive Rate", fontsize=12)
    ax.set_ylabel("True Positive Rate", fontsize=12)
    ax.set_title(f"ROC Curve — {model_name} ({split})", fontsize=13, fontweight="bold")
    ax.legend(fontsize=11, loc="lower right")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    safe_name = model_name.replace(" ", "_").replace("/", "_")
    path = os.path.join(PLOTS_DIR, f"roc_{safe_name}_{split}.png")
    plt.savefig(path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
    print(f" 📈 ROC curve saved → {path}")
# ──────────────────────────────────────────────
# Error Analysis
# ──────────────────────────────────────────────
def do_error_analysis(y_true, y_pred, y_proba, X_texts,
                      model_name: str, n_samples: int = 30):
    """
    Identify and save misclassified samples with confidence scores.

    Args:
        y_true: True labels (0 = Negative, 1 = Positive).
        y_pred: Predicted labels.
        y_proba: Predicted positive-class probabilities, or None when the
            model provides none.
        X_texts: Raw input texts, aligned with y_true.
        model_name: Used in console output and in the CSV file name.
        n_samples: Currently unused; kept for backward compatibility.

    Outputs:
        - RESULTS_DIR/error_analysis_{model_name}.csv (spaces and "/" in the
          name replaced by "_"), one row per misclassified sample with columns
          text, true_label, pred_label, error_type (FP/FN), confidence
          (positive-class probability; NaN without y_proba),
          high_confidence_mistake, text_preview (first 200 chars) and
          pred_confidence (probability given to the wrongly predicted class).
        - Console summary of error patterns and the five most confident
          mistakes.

    Returns:
        The misclassified-samples DataFrame that was written to CSV.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # dtype=object keeps each text as a Python str; a fixed-width unicode array
    # would pad every entry to the longest text and can use far more memory.
    X_texts = np.asarray(X_texts, dtype=object)
    has_proba = y_proba is not None
    # Without probabilities there is no confidence information: record NaN
    # rather than a fake 0.0 (which would flag every error as high-confidence).
    y_proba = (np.asarray(y_proba, dtype=float) if has_proba
               else np.full(len(y_true), np.nan))

    misclassified_mask = y_true != y_pred
    n_total = len(y_true)
    n_errors = int(misclassified_mask.sum())
    error_pct = n_errors / n_total * 100 if n_total else 0.0
    print(f"\n🔍 Error Analysis — {model_name}")
    print(f" Total errors: {n_errors}/{n_total} "
          f"({error_pct:.1f}%)")

    wrong_texts = X_texts[misclassified_mask]
    wrong_true = y_true[misclassified_mask]
    wrong_pred = y_pred[misclassified_mask]
    wrong_conf = y_proba[misclassified_mask]

    # Error types
    fp_mask = (wrong_true == 0) & (wrong_pred == 1)
    fn_mask = (wrong_true == 1) & (wrong_pred == 0)
    print(f" False Positives (neg→pos): {fp_mask.sum()}")
    print(f" False Negatives (pos→neg): {fn_mask.sum()}")

    # High-confidence mistakes: positive-class probability more than 0.3 away
    # from 0.5. Never flagged when no probabilities were supplied.
    if has_proba:
        high_conf_errors = np.abs(wrong_conf - 0.5) > 0.3
    else:
        high_conf_errors = np.zeros(len(wrong_conf), dtype=bool)
    print(f" High-confidence mistakes: {high_conf_errors.sum()}")

    # Probability the model gave to the class it (wrongly) predicted. Ranking by
    # the raw positive-class probability would put confident false negatives
    # (probability near 0) last instead of first.
    pred_conf = np.where(wrong_pred == 1, wrong_conf, 1.0 - wrong_conf)

    # Build DataFrame
    error_df = pd.DataFrame({
        "text": wrong_texts,
        "true_label": np.where(wrong_true == 1, "Positive", "Negative"),
        "pred_label": np.where(wrong_pred == 1, "Positive", "Negative"),
        # Every error that is not neg→pos is reported as FN.
        "error_type": np.where(fp_mask, "FP", "FN"),
        "confidence": wrong_conf,
        "high_confidence_mistake": high_conf_errors,
    })
    # Trim text for readability; astype(str) guards against non-string entries.
    error_df["text_preview"] = error_df["text"].astype(str).str[:200]
    error_df["pred_confidence"] = pred_conf

    os.makedirs(RESULTS_DIR, exist_ok=True)
    safe_name = model_name.replace(" ", "_").replace("/", "_")
    out_path = os.path.join(RESULTS_DIR, f"error_analysis_{safe_name}.csv")
    error_df.to_csv(out_path, index=False)
    print(f" 💾 Error analysis saved → {out_path}")

    # Print most confident mistakes (stable sort keeps ties in input order;
    # NaN confidences sort last).
    print(f"\n Top {min(5, n_errors)} most confident mistakes:")
    top = error_df.sort_values("pred_confidence", ascending=False,
                               kind="mergesort").head(5)
    for _, row in top.iterrows():
        print(f" [{row['error_type']}] conf={row['pred_confidence']:.3f} | "
              f"'{row['text_preview'][:80]}...'")
    return error_df
# ──────────────────────────────────────────────
# Comparison Chart
# ──────────────────────────────────────────────
def plot_model_comparison():
    """
    Plot a side-by-side comparison bar chart of all evaluated models.

    One group of bars per metric (accuracy, precision, recall, F1, ROC-AUC),
    one bar per model recorded in ALL_RESULTS, each labelled with its value.
    The chart is saved to RESULTS_DIR/model_comparison.png. Prints a warning
    and returns None when no test-split results have been recorded yet.
    """
    if not ALL_RESULTS:
        print("⚠️ No results to compare yet.")
        return

    metrics_to_plot = ["accuracy", "precision", "recall", "f1", "roc_auc"]
    # A missing ROC-AUC (stored as None) becomes NaN here.
    df = pd.DataFrame(ALL_RESULTS).T[metrics_to_plot].astype(float)
    n_models = len(df)

    fig, ax = plt.subplots(figsize=(11, 6))
    fig.patch.set_facecolor("#0f0f1a")
    ax.set_facecolor("#1a1a2e")

    x = np.arange(len(metrics_to_plot))
    # 0.22-wide bars for up to three models; narrower for more, so a metric's
    # group of bars never runs into the next group.
    width = min(0.22, 0.8 / n_models)
    # Extra colours so models beyond the third do not repeat a colour.
    colors = ["#6c63ff", "#ff6584", "#43aa8b", "#f9c74f", "#577590", "#f3722c"]
    for i, (model, row) in enumerate(df.iterrows()):
        bars = ax.bar(x + i * width, row.values, width,
                      label=model, color=colors[i % len(colors)],
                      alpha=0.9, edgecolor="white", linewidth=0.5)
        for bar, value in zip(bars, row.values):
            if np.isnan(value):
                continue  # nothing to label, e.g. ROC-AUC without probabilities
            ax.text(bar.get_x() + bar.get_width() / 2, value + 0.005,
                    f"{value:.3f}", ha="center", va="bottom",
                    fontsize=8, color="white")

    # Centre each tick under its group of bars for any number of models.
    ax.set_xticks(x + width * (n_models - 1) / 2)
    ax.set_xticklabels([m.replace("_", " ").upper() for m in metrics_to_plot],
                       fontsize=11)
    # Zoom on 0.80-1.01 unless a score falls below 0.80; then lower the floor to
    # a 0.1 step with some margin so no bar is cut off.
    scores = df.values[np.isfinite(df.values)]
    lowest = scores.min() if scores.size else 0.80
    ymin = 0.80 if lowest >= 0.80 else max(0.0, np.floor((lowest - 0.05) * 10) / 10)
    ax.set_ylim(ymin, 1.01)
    ax.set_ylabel("Score", fontsize=12)
    ax.set_title("Model Comparison — Sentiment Analysis (IMDB)", fontsize=14,
                 fontweight="bold", pad=15)
    ax.legend(fontsize=10, loc="lower right")
    ax.grid(True, axis="y", alpha=0.3)
    plt.tight_layout()

    os.makedirs(RESULTS_DIR, exist_ok=True)
    path = os.path.join(RESULTS_DIR, "model_comparison.png")
    plt.savefig(path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
    print(f"\n📊 Comparison chart saved → {path}")
def _save_metrics_csv():
    """
    Persist all model metrics to CSV.

    Writes ALL_RESULTS (one row per model, indexed by model name) to
    RESULTS_DIR/metrics_summary.csv, overwriting any previous file. Does
    nothing when no results have been recorded.
    """
    if not ALL_RESULTS:
        return
    # Don't rely on the import-time makedirs of PLOTS_DIR: RESULTS_DIR may have
    # been reassigned or removed since.
    os.makedirs(RESULTS_DIR, exist_ok=True)
    df = pd.DataFrame(ALL_RESULTS).T
    path = os.path.join(RESULTS_DIR, "metrics_summary.csv")
    df.to_csv(path)
    print(f" 💾 Metrics summary → {path}")