import json
import re
import string
import sys
import random
from argparse import ArgumentParser
from collections import Counter

from evaluate import load

bertscore = load("bertscore")

refer_file_path = sys.argv[1]
input_file_path = sys.argv[2]

# Build a lookup from question_id to (question text, gold answer, task type).
conversations = open(refer_file_path, "r").readlines()
conversations_dict = {}
for conversation in conversations:
    conv_l = json.loads(conversation.strip())
    conversations_dict[conv_l["question_id"]] = (conv_l["text"], conv_l["answer"], conv_l["type"])


class Metrics():
    def __init__(self):
        pass

    def __normalize_text(self, s_text):
        """Lower text and remove punctuation, articles and extra whitespace."""
        def remove_articles(text):
            regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
            return re.sub(regex, ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s_text))))

    def __normalize_model_outputs(self, model_text, type_category):
        """Post-process memo-writing outputs: extract quoted strings and integers
        from the raw model text and group them into structured records."""
        extracted_elements = [re.sub(r'\s+', ' ', mt.replace('"', '').replace("'", ""))
                              for mt in re.findall(r"'[^']*'|\"[^\"]*\"|\d+", model_text)]
        model_outputs = []
        ti = 0
        if "dialogsum" in type_category:
            # Expect repeating (topic, summary, start, end) key/value groups.
            while ti + 7 < len(extracted_elements):
                if extracted_elements[ti] == "topic" and extracted_elements[ti + 2] == "summary" \
                        and extracted_elements[ti + 4] == "start" and extracted_elements[ti + 6] == "end":
                    try:
                        model_outputs.append({"topic": extracted_elements[ti + 1],
                                              "summary": extracted_elements[ti + 3],
                                              "start": int(extracted_elements[ti + 5]),
                                              "end": int(extracted_elements[ti + 7])})
                    except ValueError:
                        pass
                ti += 1
        else:
            # Expect repeating (topic, start, end) key/value groups.
            while ti + 5 < len(extracted_elements):
                if extracted_elements[ti] == "topic" and extracted_elements[ti + 2] == "start" \
                        and extracted_elements[ti + 4] == "end":
                    try:
                        model_outputs.append({"topic": extracted_elements[ti + 1],
                                              "start": int(extracted_elements[ti + 3]),
                                              "end": int(extracted_elements[ti + 5])})
                    except ValueError:
                        pass
                ti += 1
        return model_outputs

    def __get_class_span_dict__(self, label, checkitem_k):
        """Map each (start, end) span to the list of normalized check items inside it."""
        class_span = {}
        for i in range(len(label)):
            checkitem_i = self.__normalize_text(label[i][checkitem_k])
            class_span[(label[i]['start'], label[i]['end'])] = \
                class_span.get((label[i]['start'], label[i]['end']), []) + [checkitem_i]
        return class_span

    def __get_intersect_by_entity__(self, pred_class_span, label_class_span):
        '''
        return the count of correct entities
        '''
        cnt = 0
        for label in label_class_span:
            cnt += len(list(set(label_class_span[label]).intersection(set(pred_class_span.get(label, [])))))
        return cnt

    def __get_bertscore_by_entity__(self, pred_class_span, label_class_span):
        '''
        return the BERTScore precision summed over matched spans (a soft count of correct entities)
        '''
        cnt = 0
        for label in label_class_span:
            if label in pred_class_span:
                references = [label_class_span[label]]
                prediction = [pred_class_span[label][0]]
                result = bertscore.compute(predictions=prediction, references=references,
                                           model_type="microsoft/deberta-xlarge-mnli")["precision"][0]
                cnt += result
        return cnt

    def __get_cnt__(self, label_class_span):
        '''
        return the count of entities
        '''
        cnt = 0
        for label in label_class_span:
            cnt += len(label_class_span[label])
            # cnt += 1  # set as 1 if we have multiple references
        return cnt

    def metrics_by_entity_(self, pred, label, checkitem_k):
        '''
        return entity-level counts of total predictions, true labels, and correct predictions
        '''
        pred_class_span = self.__get_class_span_dict__(pred, checkitem_k)
        label_class_span = self.__get_class_span_dict__(label, checkitem_k)
        pred_cnt = self.__get_cnt__(pred_class_span)
        label_cnt = self.__get_cnt__(label_class_span)
        if checkitem_k == "topic":
            correct_cnt = self.__get_intersect_by_entity__(pred_class_span, label_class_span)
        elif checkitem_k == "summary":
            correct_cnt = self.__get_bertscore_by_entity__(pred_class_span, label_class_span)
        return pred_cnt, label_cnt, correct_cnt

    def p_r_f1_by_entity(self, pc, lc, cc):
        precision = cc / (pc + 1e-8)
        recall = cc / (lc + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        return round(precision * 100, 2), round(recall * 100, 2), round(f1 * 100, 2)

    def metrics_by_entity_files(self, pred_file, checkitem_k, type_key):
        pred_cnt = 0
        label_cnt = 0
        correct_cnt = 0
        for l_i, line in enumerate(open(pred_file, "r").readlines()):
            eles = json.loads(line.strip())
            # Skip samples whose task type does not match, and skip summary scoring
            # for the topiocqa writing subset, which has no gold summaries.
            if (type_key not in conversations_dict[eles["question_id"]][2]) \
                    or (conversations_dict[eles["question_id"]][2] == "writing_topiocqa" and checkitem_k == "summary"):
                continue
            if type_key == "writing":
                model_text = self.__normalize_model_outputs(eles["text"], conversations_dict[eles["question_id"]][2])
                label_i = json.loads(conversations_dict[eles["question_id"]][1])
            elif type_key == "retrieval":
                model_text = [{"topic": v, "start": 0, "end": 0} for v in set(eles["text"].split("#"))]
                label_i = [{"topic": v, "start": 0, "end": 0}
                           for v in set(conversations_dict[eles["question_id"]][1].split("#"))]
            else:
                model_text = [{"summary": eles["text"], "start": 0, "end": 0}]
                label_i = [{"summary": conversations_dict[eles["question_id"]][1], "start": 0, "end": 0}]
            p_cnt, l_cnt, c_cnt = self.metrics_by_entity_(model_text, label_i, checkitem_k)
            p_i, r_i, f_i = self.p_r_f1_by_entity(p_cnt, l_cnt, c_cnt)
            # if p_i + r_i + f_i != 0:
            #     print("Q ID: " + str(eles["question_id"]) + "\n")
            #     print(conversations_dict[eles["question_id"]][0] + "\n")
            #     # print("Raw Output: " + eles["text"] + "\n")
            #     print("Model: {}".format(model_text) + "\n")
            #     print("Refer: {}".format(label_i) + "\n")
            #     print("Case P/R/F1: {}%, {}%, {}%".format(p_i, r_i, f_i))
            #     print("=" * 20)
            pred_cnt += p_cnt
            label_cnt += l_cnt
            correct_cnt += c_cnt
        return self.p_r_f1_by_entity(pred_cnt, label_cnt, correct_cnt)


calculate_metrics = Metrics()

p_a, r_a, f1_a = calculate_metrics.metrics_by_entity_files(input_file_path, 'topic', 'writing')  # both
print("Overall P/R/F1 of topic: {}%, {}%, {}%".format(p_a, r_a, f1_a))

p_b, r_b, f1_b = calculate_metrics.metrics_by_entity_files(input_file_path, 'summary', 'writing')  # dialogsum
print("Overall P/R/F1 of summary: {}%, {}%, {}%".format(p_b, r_b, f1_b))

_, _, f1 = calculate_metrics.metrics_by_entity_files(input_file_path, "topic", "retrieval")  # both
print("Retrieval F1: {}%".format(f1))

p, _, _ = calculate_metrics.metrics_by_entity_files(input_file_path, "summary", "chatting")  # dialogsum
print("Chatting similarity: {}%".format(p))
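# Example invocation (a minimal sketch; the file names below are placeholders, not
# fixed by this script). The first argument is the reference JSONL, one object per
# line with "question_id", "text", "answer", and "type"; the second is the
# prediction JSONL with "question_id" and "text":
#
#   python evaluate_metrics.py references.jsonl predictions.jsonl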