LexiMind / scripts /evaluate.py
OliverPerrin
updated readme, ruff formatted all files
df3ebbd
"""
Comprehensive evaluation script for LexiMind.
Evaluates all three tasks with full metrics:
- Summarization: ROUGE-1/2/L, BLEU-4, per-domain breakdown (BERTScore optional)
- Emotion: Sample-avg F1, Macro F1, Micro F1, per-class metrics, threshold tuning
- Topic: Accuracy, Macro F1, Per-class metrics, bootstrap confidence intervals
Usage:
python scripts/evaluate.py
python scripts/evaluate.py --checkpoint checkpoints/best.pt
python scripts/evaluate.py --include-bertscore # Include BERTScore (slow)
python scripts/evaluate.py --tune-thresholds # Tune per-class emotion thresholds
python scripts/evaluate.py --bootstrap # Compute confidence intervals
Author: Oliver Perrin
Date: January 2026
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
# Setup path: PROJECT_ROOT is the directory one level above this script's own
# folder. It is put at the front of sys.path (once) so that the `src.*` imports
# below resolve when the file is run directly as `python scripts/evaluate.py`.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
import torch
from sklearn.metrics import accuracy_score, classification_report, f1_score
from tqdm import tqdm
from src.data.dataset import (
load_emotion_jsonl,
load_summarization_jsonl,
load_topic_jsonl,
)
from src.inference.factory import create_inference_pipeline
from src.training.metrics import (
bootstrap_confidence_interval,
calculate_bertscore,
calculate_bleu,
calculate_rouge,
multilabel_f1,
multilabel_macro_f1,
multilabel_micro_f1,
multilabel_per_class_metrics,
tune_per_class_thresholds,
)
def _read_jsonl_records(data_path: Path) -> list:
    """Parse every non-blank line of a JSONL file with ``json.loads``."""
    # JSON text is UTF-8; do not depend on the platform's locale default encoding.
    with open(data_path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _record_domain(record: dict):
    """Return the record's ``"type"`` field, else ``"domain"``, else ``"unknown"``."""
    return record.get("type", record.get("domain", "unknown"))


def evaluate_summarization(
    pipeline,
    data_path: Path,
    max_samples: int | None = None,
    include_bertscore: bool = True,
    batch_size: int = 8,
    compute_bootstrap: bool = False,
) -> dict:
    """Evaluate summarization with comprehensive metrics and per-domain breakdown.

    Args:
        pipeline: Object exposing ``summarize(list_of_sources)`` that returns one
            generated summary per source.
        data_path: JSONL file with the evaluation examples. It is read twice: once
            through ``load_summarization_jsonl`` and once raw, to recover the
            optional ``"type"`` / ``"domain"`` field of each record.
        max_samples: If truthy, only the first ``max_samples`` examples are used.
        include_bertscore: Also compute BERTScore (overall and per domain).
        batch_size: Number of sources passed to ``pipeline.summarize`` per call.
        compute_bootstrap: Also compute bootstrap confidence intervals over
            per-sample ROUGE-1 / ROUGE-L F-measures (needs ``rouge_score``).

    Returns:
        Dict with ``rouge1``, ``rouge2``, ``rougeL``, ``bleu4`` and ``num_samples``;
        plus ``bertscore_*`` when requested, ``per_domain`` when more than one
        domain value is present, and ``rouge1_ci`` / ``rougeL_ci`` when bootstrap
        intervals were computed.
    """
    print("\n" + "=" * 60)
    print("SUMMARIZATION EVALUATION")
    print("=" * 60)

    # Load data - try to get domain info from the raw JSONL
    raw_data = _read_jsonl_records(data_path)
    data = load_summarization_jsonl(str(data_path))
    if max_samples:
        data = data[:max_samples]
        raw_data = raw_data[:max_samples]
    print(f"Evaluating on {len(data)} samples...")

    # Generate summaries
    predictions = []
    references = []
    for i in tqdm(range(0, len(data), batch_size), desc="Generating summaries"):
        batch = data[i : i + batch_size]
        predictions.extend(pipeline.summarize([ex.source for ex in batch]))
        references.extend(ex.summary for ex in batch)

    # Track domain for per-domain breakdown. Raw records are matched to examples
    # by position; an example without a raw record at its index is "unknown".
    domains = [
        _record_domain(raw_data[idx]) if idx < len(raw_data) else "unknown"
        for idx in range(len(data))
    ]

    # Calculate overall metrics
    print("\nCalculating ROUGE scores...")
    rouge_scores = calculate_rouge(predictions, references)
    print("Calculating BLEU score...")
    bleu = calculate_bleu(predictions, references)
    metrics: dict = {
        "rouge1": rouge_scores["rouge1"],
        "rouge2": rouge_scores["rouge2"],
        "rougeL": rouge_scores["rougeL"],
        "bleu4": bleu,
        "num_samples": len(predictions),
    }
    if include_bertscore:
        print("Calculating BERTScore (this may take a few minutes)...")
        bert_scores = calculate_bertscore(predictions, references)
        metrics["bertscore_precision"] = bert_scores["precision"]
        metrics["bertscore_recall"] = bert_scores["recall"]
        metrics["bertscore_f1"] = bert_scores["f1"]

    # Per-domain breakdown
    unique_domains = sorted(set(domains))
    if len(unique_domains) > 1:
        print("\nComputing per-domain breakdown...")
        domain_metrics = {}
        for domain in unique_domains:
            if domain == "unknown":
                continue
            d_preds = [p for p, d in zip(predictions, domains, strict=True) if d == domain]
            d_refs = [r for r, d in zip(references, domains, strict=True) if d == domain]
            if not d_preds:
                continue
            d_rouge = calculate_rouge(d_preds, d_refs)
            d_bleu = calculate_bleu(d_preds, d_refs)
            dm: dict = {
                "num_samples": len(d_preds),
                "rouge1": d_rouge["rouge1"],
                "rouge2": d_rouge["rouge2"],
                "rougeL": d_rouge["rougeL"],
                "bleu4": d_bleu,
            }
            if include_bertscore:
                d_bert = calculate_bertscore(d_preds, d_refs)
                dm["bertscore_f1"] = d_bert["f1"]
            domain_metrics[domain] = dm
        metrics["per_domain"] = domain_metrics

    # Bootstrap confidence intervals
    if compute_bootstrap:
        try:
            # Optional dependency: only the import is guarded so that real
            # failures in the scoring below are not swallowed.
            from rouge_score import rouge_scorer
        except ImportError:
            # The caller asked for intervals; say why they are missing.
            print("Warning: rouge_score is not installed, skipping ROUGE confidence intervals")
        else:
            scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
            per_sample_r1 = []
            per_sample_rL = []
            for pred, ref in zip(predictions, references, strict=True):
                scores = scorer.score(ref, pred)
                per_sample_r1.append(scores["rouge1"].fmeasure)
                per_sample_rL.append(scores["rougeL"].fmeasure)
            r1_mean, r1_lo, r1_hi = bootstrap_confidence_interval(per_sample_r1)
            rL_mean, rL_lo, rL_hi = bootstrap_confidence_interval(per_sample_rL)
            metrics["rouge1_ci"] = {"mean": r1_mean, "lower": r1_lo, "upper": r1_hi}
            metrics["rougeL_ci"] = {"mean": rL_mean, "lower": rL_lo, "upper": rL_hi}

    # Print results
    print("\n" + "-" * 40)
    print("SUMMARIZATION RESULTS:")
    print("-" * 40)
    print(f" ROUGE-1: {metrics['rouge1']:.4f}")
    print(f" ROUGE-2: {metrics['rouge2']:.4f}")
    print(f" ROUGE-L: {metrics['rougeL']:.4f}")
    print(f" BLEU-4: {metrics['bleu4']:.4f}")
    if include_bertscore:
        print(f" BERTScore P: {metrics['bertscore_precision']:.4f}")
        print(f" BERTScore R: {metrics['bertscore_recall']:.4f}")
        print(f" BERTScore F: {metrics['bertscore_f1']:.4f}")
    if "per_domain" in metrics:
        print("\n Per-Domain Breakdown:")
        for domain, dm in metrics["per_domain"].items():
            bs_str = f", BS-F1={dm['bertscore_f1']:.4f}" if "bertscore_f1" in dm else ""
            print(
                f" {domain} (n={dm['num_samples']}): R1={dm['rouge1']:.4f}, RL={dm['rougeL']:.4f}, B4={dm['bleu4']:.4f}{bs_str}"
            )
    if "rouge1_ci" in metrics:
        ci = metrics["rouge1_ci"]
        print(f"\n ROUGE-1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
    if "rougeL_ci" in metrics:
        # Computed alongside the ROUGE-1 interval, so report it as well.
        ci = metrics["rougeL_ci"]
        print(f" ROUGE-L 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")

    # Show examples
    print("\n" + "-" * 40)
    print("SAMPLE OUTPUTS:")
    print("-" * 40)
    for i in range(min(3, len(predictions))):
        print(f"\nExample {i + 1}:")
        print(f" Source: {data[i].source[:100]}...")
        print(f" Generated: {predictions[i][:150]}...")
        print(f" Reference: {references[i][:150]}...")
    return metrics
def evaluate_emotion(
    pipeline,
    data_path: Path,
    max_samples: int | None = None,
    batch_size: int = 32,
    tune_thresholds: bool = False,
    compute_bootstrap: bool = False,
) -> dict:
    """Evaluate emotion detection with comprehensive multi-label metrics.

    Reports sample-averaged F1, macro F1, micro F1, and per-class breakdown.
    Optionally tunes per-class thresholds on the evaluation set. The tuned
    scores are fitted and measured on the same examples, so they are an
    optimistic estimate.

    Args:
        pipeline: Object exposing ``predict_emotions(texts)`` (items with a
            ``labels`` attribute) and ``emotion_labels``; threshold tuning also
            uses its ``tokenizer``, ``model`` and ``device``.
        data_path: JSONL file loaded through ``load_emotion_jsonl``.
        max_samples: If truthy, only the first ``max_samples`` examples are used.
        batch_size: Number of texts per prediction call.
        tune_thresholds: Collect raw logits and tune one threshold per class.
        compute_bootstrap: Add a bootstrap confidence interval over per-sample F1.

    Returns:
        Dict with ``sample_avg_f1``, ``macro_f1``, ``micro_f1``, ``num_samples``,
        ``num_classes`` and ``per_class``; plus ``tuned_*`` entries when tuning
        ran and ``sample_f1_ci`` when bootstrap intervals were requested.
    """
    print("\n" + "=" * 60)
    print("EMOTION DETECTION EVALUATION")
    print("=" * 60)

    # Load data (returns EmotionExample dataclass objects)
    data = load_emotion_jsonl(str(data_path))
    if max_samples:
        data = data[:max_samples]
    print(f"Evaluating on {len(data)} samples...")

    # Get predictions - collect raw logits for threshold tuning
    all_preds = []
    all_refs = []
    all_logits_list = []
    for i in tqdm(range(0, len(data), batch_size), desc="Predicting emotions"):
        batch = data[i : i + batch_size]
        texts = [ex.text for ex in batch]
        refs = [set(ex.emotions) for ex in batch]
        preds = pipeline.predict_emotions(texts)
        pred_sets = [set(p.labels) for p in preds]
        all_preds.extend(pred_sets)
        all_refs.extend(refs)
        # Also get raw logits for threshold tuning
        if tune_thresholds:
            encoded = pipeline.tokenizer.batch_encode(texts)
            input_ids = encoded["input_ids"].to(pipeline.device)
            attention_mask = encoded["attention_mask"].to(pipeline.device)
            with torch.inference_mode():
                logits = pipeline.model.forward(
                    "emotion", {"input_ids": input_ids, "attention_mask": attention_mask}
                )
            all_logits_list.append(logits.cpu())

    # Calculate metrics
    all_emotions = sorted(pipeline.emotion_labels)

    def to_binary(emotion_sets, labels):
        """Encode each label set as a 0/1 row with one column per entry of *labels*."""
        return [[1 if e in es else 0 for e in labels] for es in emotion_sets]

    pred_binary = torch.tensor(to_binary(all_preds, all_emotions))
    ref_binary = torch.tensor(to_binary(all_refs, all_emotions))

    # Core metrics: sample-avg F1, macro F1, micro F1
    sample_f1 = multilabel_f1(pred_binary, ref_binary)
    macro_f1 = multilabel_macro_f1(pred_binary, ref_binary)
    micro_f1 = multilabel_micro_f1(pred_binary, ref_binary)

    # Per-class metrics
    per_class = multilabel_per_class_metrics(pred_binary, ref_binary, class_names=all_emotions)
    metrics: dict = {
        "sample_avg_f1": sample_f1,
        "macro_f1": macro_f1,
        "micro_f1": micro_f1,
        "num_samples": len(all_preds),
        "num_classes": len(all_emotions),
        "per_class": per_class,
    }

    # Per-class threshold tuning
    if tune_thresholds and all_logits_list:
        print("\nTuning per-class thresholds...")
        all_logits = torch.cat(all_logits_list, dim=0)
        # The raw logits keep the model's own column order, whereas ref_binary
        # above uses alphabetically sorted labels. Build the references in the
        # pipeline's label order so each logit column is compared with its own
        # class. This is identical to the sorted matrix when the labels are
        # already sorted.
        # NOTE(review): assumes logit column i corresponds to
        # pipeline.emotion_labels[i] - confirm against the model head.
        logit_labels = list(pipeline.emotion_labels)
        tuning_refs = torch.tensor(to_binary(all_refs, logit_labels))
        best_thresholds, tuned_macro_f1 = tune_per_class_thresholds(all_logits, tuning_refs)
        # float() keeps the report JSON-serializable even if the tuner returns
        # tensor or NumPy scalars.
        metrics["tuned_thresholds"] = {
            name: float(thresh)
            for name, thresh in zip(logit_labels, best_thresholds, strict=True)
        }
        metrics["tuned_macro_f1"] = float(tuned_macro_f1)
        # Also compute tuned sample-avg F1
        probs = torch.sigmoid(all_logits)
        tuned_preds = torch.zeros_like(probs)
        for c, t in enumerate(best_thresholds):
            tuned_preds[:, c] = (probs[:, c] >= t).float()
        metrics["tuned_sample_avg_f1"] = multilabel_f1(tuned_preds, tuning_refs)
        metrics["tuned_micro_f1"] = multilabel_micro_f1(tuned_preds, tuning_refs)

    # Bootstrap confidence intervals
    if compute_bootstrap:
        # Compute per-sample F1 for bootstrapping
        per_sample_f1s = []
        for pred, ref in zip(all_preds, all_refs, strict=True):
            if not pred and not ref:
                # Nothing expected and nothing predicted counts as a perfect match.
                per_sample_f1s.append(1.0)
                continue
            overlap = len(pred & ref)
            if overlap == 0:
                # Covers an empty side as well as two disjoint non-empty sets.
                per_sample_f1s.append(0.0)
                continue
            p = overlap / len(pred)
            r = overlap / len(ref)
            per_sample_f1s.append(2 * p * r / (p + r))
        mean, lo, hi = bootstrap_confidence_interval(per_sample_f1s)
        metrics["sample_f1_ci"] = {"mean": mean, "lower": lo, "upper": hi}

    # Print results
    print("\n" + "-" * 40)
    print("EMOTION DETECTION RESULTS:")
    print("-" * 40)
    print(f" Sample-avg F1: {metrics['sample_avg_f1']:.4f}")
    print(f" Macro F1: {metrics['macro_f1']:.4f}")
    print(f" Micro F1: {metrics['micro_f1']:.4f}")
    print(f" Num Classes: {metrics['num_classes']}")
    if "tuned_macro_f1" in metrics:
        print("\n After per-class threshold tuning:")
        print(f" Tuned Macro F1: {metrics['tuned_macro_f1']:.4f}")
        print(f" Tuned Sample-avg F1: {metrics['tuned_sample_avg_f1']:.4f}")
        print(f" Tuned Micro F1: {metrics['tuned_micro_f1']:.4f}")
    if "sample_f1_ci" in metrics:
        ci = metrics["sample_f1_ci"]
        print(f"\n Sample F1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")

    # Print top-10 per-class performance
    print("\n Per-class F1 (top 10 by support):")
    sorted_classes = sorted(per_class.items(), key=lambda x: x[1]["support"], reverse=True)
    for name, m in sorted_classes[:10]:
        print(
            f" {name:20s}: P={m['precision']:.3f} R={m['recall']:.3f} F1={m['f1']:.3f} (n={m['support']})"
        )
    return metrics
def evaluate_topic(
    pipeline,
    data_path: Path,
    max_samples: int | None = None,
    batch_size: int = 32,
    compute_bootstrap: bool = False,
) -> dict:
    """Score single-label topic predictions against the labelled examples.

    Loads the examples with ``load_topic_jsonl``, predicts in batches through
    ``pipeline.predict_topics`` and reports accuracy and macro F1, followed by
    a per-class classification report on stdout. When ``compute_bootstrap`` is
    set, a bootstrap confidence interval over per-sample correctness is added
    under ``accuracy_ci``.

    Returns:
        Dict with ``accuracy``, ``macro_f1``, ``num_samples`` and, optionally,
        ``accuracy_ci``.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 40

    print("\n" + heavy_rule)
    print("TOPIC CLASSIFICATION EVALUATION")
    print(heavy_rule)

    # Examples expose ``.text`` and ``.topic``.
    examples = load_topic_jsonl(str(data_path))
    if max_samples:
        examples = examples[:max_samples]
    print(f"Evaluating on {len(examples)} samples...")

    # Batched inference, keeping gold and predicted labels in lockstep.
    predicted = []
    gold = []
    for start in tqdm(range(0, len(examples), batch_size), desc="Predicting topics"):
        chunk = examples[start : start + batch_size]
        outputs = pipeline.predict_topics([ex.text for ex in chunk])
        predicted.extend(out.label for out in outputs)
        gold.extend(ex.topic for ex in chunk)

    metrics: dict = {
        "accuracy": accuracy_score(gold, predicted),
        "macro_f1": f1_score(gold, predicted, average="macro", zero_division=0),
        "num_samples": len(predicted),
    }

    # Optional confidence interval over the 0/1 correctness of every sample.
    if compute_bootstrap:
        hits = []
        for guess, truth in zip(predicted, gold, strict=True):
            hits.append(1.0 if guess == truth else 0.0)
        mean, lo, hi = bootstrap_confidence_interval(hits)
        metrics["accuracy_ci"] = {"mean": mean, "lower": lo, "upper": hi}

    # Headline numbers.
    print("\n" + light_rule)
    print("TOPIC CLASSIFICATION RESULTS:")
    print(light_rule)
    print(f" Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy'] * 100:.1f}%)")
    print(f" Macro F1: {metrics['macro_f1']:.4f}")
    interval = metrics.get("accuracy_ci")
    if interval is not None:
        print(f" Accuracy 95% CI: [{interval['lower']:.4f}, {interval['upper']:.4f}]")

    # Per-class precision / recall / F1 table.
    print("\n" + light_rule)
    print("PER-CLASS METRICS:")
    print(light_rule)
    print(classification_report(gold, predicted, zero_division=0))
    return metrics
def _find_validation_file(data_dir: Path, task: str) -> Path | None:
    """Return the task's validation JSONL, trying ``validation.jsonl`` then ``val.jsonl``."""
    for filename in ("validation.jsonl", "val.jsonl"):
        candidate = data_dir / task / filename
        if candidate.exists():
            return candidate
    return None


def main():
    """Command-line entry point: evaluate the selected tasks and save a JSON report.

    Parses the CLI flags, builds the inference pipeline from the checkpoint and
    label file, runs each requested task whose validation file exists, writes
    the collected metrics to ``--output`` and prints a short closing summary.
    """
    parser = argparse.ArgumentParser(description="Evaluate LexiMind model")
    parser.add_argument("--checkpoint", type=Path, default=Path("checkpoints/best.pt"))
    parser.add_argument("--labels", type=Path, default=Path("artifacts/labels.json"))
    parser.add_argument("--data-dir", type=Path, default=Path("data/processed"))
    parser.add_argument("--output", type=Path, default=Path("outputs/evaluation_report.json"))
    parser.add_argument("--max-samples", type=int, default=None, help="Limit samples per task")
    parser.add_argument(
        "--include-bertscore",
        action="store_true",
        help="Include BERTScore (slow, optional)",
    )
    parser.add_argument(
        "--tune-thresholds",
        action="store_true",
        help="Tune per-class emotion thresholds on val set",
    )
    parser.add_argument(
        "--bootstrap",
        action="store_true",
        help="Compute bootstrap confidence intervals",
    )
    parser.add_argument("--summarization-only", action="store_true")
    parser.add_argument("--emotion-only", action="store_true")
    parser.add_argument("--topic-only", action="store_true")
    args = parser.parse_args()

    banner = "=" * 60
    print(banner)
    print("LexiMind Evaluation")
    print(banner)
    start_time = time.perf_counter()

    # Build the inference pipeline, on GPU when one is available.
    print(f"\nLoading model from {args.checkpoint}...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline, labels = create_inference_pipeline(args.checkpoint, args.labels, device=device)
    print(f" Device: {device}")
    print(f" Topics: {labels.topic}")
    print(f" Emotions: {len(labels.emotion)} classes")

    results = {}
    # With no ``--*-only`` flag, every task is evaluated.
    run_everything = not any((args.summarization_only, args.emotion_only, args.topic_only))

    # Summarization
    if run_everything or args.summarization_only:
        summ_path = _find_validation_file(args.data_dir, "summarization")
        if summ_path is None:
            print("Warning: summarization validation data not found, skipping")
        else:
            results["summarization"] = evaluate_summarization(
                pipeline,
                summ_path,
                max_samples=args.max_samples,
                include_bertscore=args.include_bertscore,
                compute_bootstrap=args.bootstrap,
            )

    # Emotion
    if run_everything or args.emotion_only:
        emotion_path = _find_validation_file(args.data_dir, "emotion")
        if emotion_path is None:
            print("Warning: emotion validation data not found, skipping")
        else:
            results["emotion"] = evaluate_emotion(
                pipeline,
                emotion_path,
                max_samples=args.max_samples,
                tune_thresholds=args.tune_thresholds,
                compute_bootstrap=args.bootstrap,
            )

    # Topic
    if run_everything or args.topic_only:
        topic_path = _find_validation_file(args.data_dir, "topic")
        if topic_path is None:
            print("Warning: topic validation data not found, skipping")
        else:
            results["topic"] = evaluate_topic(
                pipeline,
                topic_path,
                max_samples=args.max_samples,
                compute_bootstrap=args.bootstrap,
            )

    # Persist the report, creating the output directory if needed.
    print("\n" + banner)
    print("SAVING RESULTS")
    print(banner)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w") as report_file:
        json.dump(results, report_file, indent=2)
    print(f" Saved to: {args.output}")

    # Closing summary of the headline metrics.
    elapsed = time.perf_counter() - start_time
    print("\n" + banner)
    print("EVALUATION COMPLETE")
    print(banner)
    print(f" Time: {elapsed / 60:.1f} minutes")

    summ = results.get("summarization")
    if summ is not None:
        print("\n Summarization:")
        print(f" ROUGE-1: {summ['rouge1']:.4f}")
        print(f" ROUGE-2: {summ['rouge2']:.4f}")
        print(f" ROUGE-L: {summ['rougeL']:.4f}")
        print(f" BLEU-4: {summ['bleu4']:.4f}")
        if "bertscore_f1" in summ:
            print(f" BERTScore F1: {summ['bertscore_f1']:.4f}")

    emo = results.get("emotion")
    if emo is not None:
        print("\n Emotion:")
        print(f" Sample-avg F1: {emo['sample_avg_f1']:.4f}")
        print(f" Macro F1: {emo['macro_f1']:.4f}")
        print(f" Micro F1: {emo['micro_f1']:.4f}")

    if "topic" in results:
        print("\n Topic:")
        print(f" Accuracy: {results['topic']['accuracy']:.2%}")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()