| import math |
|
|
def get_topk_results(predictions, scores, targets, k, all_items=None):
    """Convert flat candidate lists into per-target 0/1 hit rows ranked by score.

    ``predictions`` and ``scores`` are parallel, flat sequences in which the
    ``k`` candidates for ``targets[b]`` occupy positions ``b * k`` up to
    ``(b + 1) * k``.

    Args:
        predictions: Generated strings. Only the text after the last
            ``"Response:"`` marker is kept; it is stripped and all spaces
            are removed before comparison.
        scores: One score per prediction; higher ranks first. This sequence
            is never modified.
        targets: Expected item string for each group of ``k`` candidates.
        k: Number of candidates per target.
        all_items: Optional container of valid item strings. A candidate that
            is not in it is ranked with a score of ``-1000``. Passing a set
            keeps the membership test cheap.

    Returns:
        A list with one row per target. Each row holds, in descending score
        order, ``1`` where the candidate equals the target and ``0`` elsewhere.
    """
    invalid_score = -1000

    cleaned = [
        text.split("Response:")[-1].strip().replace(" ", "")
        for text in predictions
    ]

    results = []
    for index, target_item in enumerate(targets):
        start, stop = index * k, (index + 1) * k
        group_seqs = cleaned[start:stop]
        group_scores = scores[start:stop]

        if all_items is not None:
            # Penalise unknown items on a local copy so the caller's
            # ``scores`` object is left untouched.
            group_scores = [
                score if seq in all_items else invalid_score
                for seq, score in zip(group_seqs, group_scores)
            ]

        # ``sorted`` is stable, so candidates with equal scores keep their
        # original relative order.
        ranked = sorted(
            zip(group_seqs, group_scores),
            key=lambda pair: pair[1],
            reverse=True,
        )
        results.append([1 if seq == target_item else 0 for seq, _ in ranked])

    return results
|
|
def get_metrics_results(topk_results, metrics):
    """Compute every requested ``name@k`` metric over ranked 0/1 hit rows.

    Args:
        topk_results: Rows of 0/1 flags, one row per evaluated target.
        metrics: Metric names such as ``"hit@5"`` or ``"ndcg@10"``. The name
            is matched case-insensitively by prefix (``hit`` or ``ndcg``) and
            the cutoff is the integer that follows the first ``@``.

    Returns:
        A dict mapping each metric string, exactly as given, to the value
        returned by ``hit_k`` or ``ndcg_k`` for that cutoff.

    Raises:
        NotImplementedError: If a metric starts with neither ``hit`` nor
            ``ndcg``. The message names the offending metric.
    """
    res = {}
    for m in metrics:
        lowered = m.lower()
        if lowered.startswith("hit"):
            scorer = hit_k
        elif lowered.startswith("ndcg"):
            scorer = ndcg_k
        else:
            raise NotImplementedError("Unsupported metric: {!r}".format(m))

        # Parse the cutoff once, only after the metric family is recognised.
        k = int(m.split("@")[1])
        res[m] = scorer(topk_results, k)

    return res
|
|
|
|
def ndcg_k(topk_results, k):
    """Return the summed discounted gain at cutoff ``k`` over all rows.

    Each row is a ranked list of relevance flags. The flag at 0-based position
    ``i`` contributes ``flag / log2(i + 2)``; only the first ``k`` positions
    of a row are considered.

    Note that the result is the *sum* over rows: no averaging and no
    ideal-DCG normalisation is applied here. When a row contains at most one
    relevant entry the ideal DCG is 1, so the per-row value already equals
    NDCG.

    Args:
        topk_results: Rows of 0/1 relevance flags in ranked order.
        k: Number of leading positions of each row to score.

    Returns:
        The float sum of per-row discounted gains (``0.0`` for no rows).
    """
    total = 0.0
    for row in topk_results:
        row_gain = 0.0
        for position, relevance in enumerate(row[:k]):
            # log2 avoids the extra rounding of math.log(x, 2), so discounts
            # at power-of-two ranks are exact.
            row_gain += relevance / math.log2(position + 2)
        total += row_gain
    return total
|
|
|
|
def hit_k(topk_results, k):
    """Count the rows that have a positive flag total in their first ``k`` slots.

    Args:
        topk_results: Rows of 0/1 flags in ranked order.
        k: Number of leading positions of each row to inspect.

    Returns:
        The number of qualifying rows as a float; the count is not divided
        by the number of rows.
    """
    matched_rows = sum(1 for ranked in topk_results if sum(ranked[:k]) > 0)
    return float(matched_rows)
|
|
|
|