import numpy as np import logging from typing import Dict import numpy as np import pytrec_eval logger = logging.getLogger(__name__) class RankingMetrics: def __init__(self, metric_list, k_list=(1, 5, 10)): """ Initialize retrieval metrics. Args: metric_list (tuple or list): Metrics to compute ("precision", "recall", "ndcg", "map", "mrr"). k_list (tuple or list): List of K values (e.g., (1, 5, 10)) for evaluation. """ self.metric_list = metric_list self.k_list = sorted(k_list) # Ensure K is in ascending order def precision_at_k(self, prediction, true_labels, k): """ Compute Precision@K for multiple true labels. Precision@K = (Number of relevant items in top K) / K """ if not true_labels: # No relevant items return 0.0 if k == 0: return 0.0 predicted_k = prediction[:k] relevant_hits = len(set(predicted_k).intersection(set(true_labels))) return relevant_hits / k def recall_at_k(self, prediction, true_labels, k): """ Compute Recall@K for multiple true labels. Recall@K = (Number of relevant items in top K) / (Total number of relevant items) """ if not true_labels: # No relevant items return 1.0 # Or 0.0 depending on convention if no relevant items exist. # Typically, if there are no true positives, recall is undefined or 1 if no predictions are made. # Let's assume if true_labels is empty, it means perfect recall if prediction is also empty, or 0 if not. # A common convention is that if the set of true positives is empty, recall is 1.0. # However, if true_labels is empty but predictions are made, it might be 0. # Let's use the common definition: if len(true_labels) == 0, recall is 1.0 if no items were expected and none were retrieved. # If items were expected (true_labels not empty) but none retrieved, recall is 0. # If true_labels is empty and nothing is predicted, it's 1.0. # If true_labels is empty and something is predicted, it's 0.0. (This is debatable) # For simplicity here, if there are no true_labels, then no relevant items can be found. return 1.0 # All 0 relevant items were recalled if k == 0 and not true_labels: # no predictions and no true labels return 1.0 if k == 0 and true_labels: # no predictions but true labels exist return 0.0 predicted_k = prediction[:k] relevant_hits = len(set(predicted_k).intersection(set(true_labels))) return relevant_hits / len(true_labels) def _get_relevant_hits_and_predicted_k(self, prediction, true_labels, k): """Helper function to get common elements for calculations.""" if k == 0: return 0, [] predicted_k = prediction[:k] if not true_labels: # No relevant items return 0, predicted_k relevant_hits = len(set(predicted_k).intersection(set(true_labels))) return relevant_hits, predicted_k def hit_at_k(self, prediction, true_labels, k): """ Compute Hit@K (or Hit Rate@K). Returns 1.0 if at least one true label is found in the top K predictions, 0.0 otherwise. """ if not true_labels: # No relevant items to hit return 0.0 if k == 0: return 0.0 relevant_hits, _ = self._get_relevant_hits_and_predicted_k(prediction, true_labels, k) return 1.0 if relevant_hits > 0 else 0.0 def f1_at_k(self, prediction, true_labels, k): """ Compute F1-score@K. F1@K = 2 * (Precision@K * Recall@K) / (Precision@K + Recall@K) """ p_k = self.precision_at_k(prediction, true_labels, k) r_k = self.recall_at_k(prediction, true_labels, k) if (p_k + r_k) == 0: return 0.0 return 2 * (p_k * r_k) / (p_k + r_k) def average_precision_at_k(self, prediction, true_labels, k): """ Compute Average Precision (AP)@K for multiple true labels. AP is the average of precision@i for i where the i-th item is relevant. """ if not true_labels: return 0.0 # Or 1.0 depending on definition if no relevant items. Let's use 0.0 if no true positives. if k == 0: return 0.0 predicted_k = prediction[:k] score = 0.0 num_hits = 0 for i, p in enumerate(predicted_k): if p in true_labels: num_hits += 1 score += num_hits / (i + 1.0) # Precision at this new hit if not num_hits: # No hits among top k return 0.0 return score / min(len(true_labels), k) # Normalize by min of total relevant or K for AP@K # Standard AP normalizes by total relevant items (len(true_labels)) def mean_average_precision_at_k(self, test_cases_for_map_mrr, k_val): """Helper for MAP@k""" ap_scores = [] for case in test_cases_for_map_mrr: prediction, true_labels_list = case["prediction"], case["label"] # Ensure true_labels_list is always a list for consistency if not isinstance(true_labels_list, list): true_labels_list = [true_labels_list] ap_scores.append(self.average_precision_at_k(prediction, true_labels_list, k_val)) return np.mean(ap_scores) if ap_scores else 0.0 def mean_reciprocal_rank_at_k(self, test_cases_for_map_mrr, k_val): """Helper for MRR@k""" rr_scores = [] for case in test_cases_for_map_mrr: prediction, true_labels_list = case["prediction"], case["label"] if not isinstance(true_labels_list, list): true_labels_list = [true_labels_list] if not true_labels_list: rr_scores.append(0.0) # Or handle as per definition, no relevant items. continue predicted_k = prediction[:k_val] for rank, p_item in enumerate(predicted_k): if p_item in true_labels_list: rr_scores.append(1.0 / (rank + 1)) break # Found first relevant item else: # No relevant item found in top K rr_scores.append(0.0) return np.mean(rr_scores) if rr_scores else 0.0 def ndcg_at_k(self, prediction, true_labels, k, rel_scores, form="linear"): """ Compute Normalized Discounted Cumulative Gain (NDCG@K) with optional graded relevance and exponential gain. Args: prediction (list): Ranked list of predicted document IDs. true_labels (str or list): Relevant document(s). rel_scores (list or None): If provided, must be same length as label. Interpreted as graded relevance. If None, use binary relevance. """ def dcg(rel_list, form): if form == "linear": return sum(((rel) / np.log2(idx + 2)) for idx, rel in enumerate(rel_list)) elif form == "exponential": return sum(((2 ** rel - 1) / np.log2(idx + 2)) for idx, rel in enumerate(rel_list)) if rel_scores is None: # Binary relevance if isinstance(true_labels, list): label_set = set(true_labels) else: label_set = {true_labels} relevance = [1 if item in label_set else 0 for item in prediction[:k]] ideal_relevance = [1] * min(len(label_set), k) else: # Graded relevance if not isinstance(true_labels, list) or not isinstance(rel_scores, list) or len(true_labels) != len(rel_scores): raise ValueError("If rels is provided, label must be a list of the same length.") label_rels = dict(zip(true_labels, rel_scores)) relevance = [label_rels.get(item, 0) for item in prediction[:k]] ideal_relevance = sorted(label_rels.values(), reverse=True)[:k] dcg_score = dcg(relevance, form) idcg_score = dcg(ideal_relevance, form) if ideal_relevance else 1 # Avoid division by zero return dcg_score / idcg_score if idcg_score > 0 else 0.0 def evaluate(self, test_cases): """ Evaluates predictions against true labels for a list of test cases. Handles multi-label by ensuring `label` is a list of true positives. Args: test_cases (list of dict): Each dict should have "prediction" (list) and "label" (list of true positive labels). """ metric_results_accumulators = { (f"ndcg_{v}" if metric == "ndcg" else metric): {k: [] for k in self.k_list} for metric in self.metric_list for v in (["linear", "exponential"] if metric == "ndcg" else [None]) } processed_test_cases_for_map_mrr = [] for case_idx, case in enumerate(test_cases): prediction = case["prediction"] true_labels = case["label"] if not isinstance(true_labels, (list, set)): # Allow set for true_labels true_labels = [true_labels] true_labels = list(set(true_labels)) # Ensure unique true labels and convert to list for consistency # For MAP and MRR, we often want the original list of true_labels, # even if empty, for consistent calculation across cases. if "map" in self.metric_list or "mrr" in self.metric_list: processed_test_cases_for_map_mrr.append( {"prediction": prediction, "label": true_labels, "id": case_idx} ) for k in self.k_list: if "precision" in self.metric_list: metric_results_accumulators["precision"][k].append(self.precision_at_k(prediction, true_labels, k)) if "recall" in self.metric_list: metric_results_accumulators["recall"][k].append(self.recall_at_k(prediction, true_labels, k)) if "hit" in self.metric_list: metric_results_accumulators["hit"][k].append(self.hit_at_k(prediction, true_labels, k)) if "f1" in self.metric_list: metric_results_accumulators["f1"][k].append(self.f1_at_k(prediction, true_labels, k)) if "ndcg" in self.metric_list: rel_scores = case["rel_scores"] metric_results_accumulators["ndcg_linear"][k].append( self.ndcg_at_k(prediction, true_labels, k, rel_scores, form="linear")) metric_results_accumulators["ndcg_exponential"][k].append( self.ndcg_at_k(prediction, true_labels, k, rel_scores, form="exponential")) # Calculate final mean for most metrics score_dict = {} for metric, k_values_dict in metric_results_accumulators.items(): if metric not in ["map", "mrr"]: # MAP/MRR calculated differently for k, values in k_values_dict.items(): score_dict[f"{metric}@{k}"] = float(np.mean(values)) if values else 0.0 # Calculate MAP and MRR across all test cases for each k if "map" in self.metric_list: for k_val in self.k_list: ap_scores = [ self.average_precision_at_k(case["prediction"], case["label"], k_val) for case in processed_test_cases_for_map_mrr ] score_dict[f"map@{k_val}"] = float(np.mean(ap_scores)) if ap_scores else 0.0 if "mrr" in self.metric_list: for k_val in self.k_list: rr_scores = [] for case in processed_test_cases_for_map_mrr: reciprocal_rank = 0.0 if case["label"]: # Only if there are true labels for rank, p_item in enumerate(case["prediction"][:k_val]): if p_item in case["label"]: reciprocal_rank = 1.0 / (rank + 1) break rr_scores.append(reciprocal_rank) score_dict[f"mrr@{k_val}"] = float(np.mean(rr_scores)) if rr_scores else 0.0 return score_dict