| | import multiprocessing |
| | from math import isclose |
| | import numpy as np |
| | from typing import Union, Any, Dict |
| |
|
| | from sympy import simplify, N |
| | from sympy.parsing.sympy_parser import parse_expr |
| | from sympy.parsing.latex import parse_latex |
| | import re |
| | import regex |
| |
|
| | from data_processing.answer_extraction import ( |
| | extract_answer, |
| | extract_program_output, |
| | strip_string, |
| | ) |
| |
|
| |
|
def extract_program(result: str, last_only=True):
    """
    Collect the code found inside "```python" ... "```" fenced blocks.

    Args:
        result: Text that may contain one or more fenced code blocks.
        last_only: When true, every new "```python" fence discards what was
            collected so far, so only the last block survives. When false,
            all blocks are kept and each one is preceded by a
            "# ========" separator line.

    Returns:
        The collected code, every line terminated by a newline; an empty
        string when no "```python" fence is present.
    """
    pieces = []
    inside_block = False
    for text_line in result.split("\n"):
        if text_line.startswith("```python"):
            if last_only:
                pieces.clear()
            else:
                pieces.append("\n# ========\n")
            inside_block = True
            continue
        if text_line.startswith("```"):
            # Any other fence line closes the current block.
            inside_block = False
            continue
        if inside_block:
            pieces.append(text_line + "\n")
    return "".join(pieces)
| |
|
| |
|
def parse_ground_truth(example: Dict[str, Any], data_name):
    """
    Pull the reference rationale and reference answer out of ``example``.

    If the example already carries a ``"gt_cot"`` key, that value and the
    ``strip_string``-processed ``"gt"`` value are returned directly.
    Otherwise the fields to read are selected by ``data_name``.

    Args:
        example: One dataset record.
        data_name: Dataset identifier ("math", "ocw", "gsm8k", "gsm-hard",
            "svamp", "asdiv", "mawps", "tabmwp" or "bbh").

    Returns:
        A ``(rationale, answer)`` pair. On the dataset-specific path the
        rationale is converted with ``str(...).strip()`` and the answer is
        passed through ``strip_string``.

    Raises:
        NotImplementedError: ``data_name`` is not one of the handled datasets.
    """
    # A pre-extracted ground truth wins over any dataset-specific parsing.
    if "gt_cot" in example:
        return example["gt_cot"], strip_string(example["gt"])

    if data_name in ["math", "ocw"]:
        rationale = example["solution"]
        answer = extract_answer(rationale)
    elif data_name == "gsm8k":
        rationale, answer = example["answer"].split("####")
    elif data_name == "gsm-hard":
        rationale, answer = example["code"], example["target"]
    elif data_name == "svamp":
        rationale, answer = example["Equation"], example["Answer"]
    elif data_name == "asdiv":
        rationale = example["formula"]
        # Drop every parenthesised fragment from the answer text.
        answer = re.sub(r"\(.*?\)", "", example["answer"])
    elif data_name in ("mawps", "bbh"):
        # These two datasets have no rationale field; only a target.
        rationale, answer = None, example["target"]
    elif data_name == "tabmwp":
        rationale, answer = example["solution"], example["answer"]
        if example["ans_type"] in ["integer_number", "decimal_number"]:
            # Numeric answers may be written as a fraction, with thousands
            # separators, or as a percentage.
            if "/" in answer:
                pieces = answer.split("/")
                answer = int(pieces[0]) / int(pieces[1])
            elif "," in answer:
                answer = float(answer.replace(",", ""))
            elif "%" in answer:
                answer = float(answer.split("%")[0]) / 100
            else:
                answer = float(answer)
    else:
        raise NotImplementedError(data_name)

    return str(rationale).strip(), strip_string(answer)
| |
|
| |
|
def parse_question(example, data_name):
    """
    Build the question text for ``example`` according to ``data_name``.

    * "asdiv": stripped ``body`` and ``question`` joined by a space.
    * "svamp": stripped ``Body`` (a "." is appended when missing) followed
      by the stripped ``Question``.
    * "tabmwp": a prompt containing the optional table title, the table,
      the question and, when present, the answer choices.
    * anything else: the first of the keys "question", "problem",
      "Question", "input" that exists in ``example``.

    Returns:
        The question with surrounding whitespace removed.

    Raises:
        AssertionError: no question text could be built.
    """
    if data_name == "asdiv":
        text = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        context = example["Body"].strip()
        if not context.endswith("."):
            context += "."
        text = f'{context} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title = example["table_title"]
        title_clause = f'regarding "{title}" ' if title else ""
        segments = [
            f"Read the following table {title_clause}and answer a question:\n",
            f'{example["table"]}\n{example["question"]}',
        ]
        if example["choices"]:
            segments.append(
                f' Please select from the following options: {example["choices"]}'
            )
        text = "".join(segments)
    else:
        candidate_keys = ("question", "problem", "Question", "input")
        text = next((example[key] for key in candidate_keys if key in example), "")
    assert text != ""
    return text.strip()
| |
|
| |
|
def run_execute(executor, result, prompt_type, execute=False):
    """
    Turn a raw model output into a ``(prediction, report)`` pair.

    Args:
        executor: Object whose ``apply(code)`` returns ``(prediction, report)``;
            only used for "pot"/"pal" prompts when ``execute`` is true.
        result: Raw model output. A falsy value or the literal "error"
            short-circuits to ``(None, None)``.
        prompt_type: Prompt style; selects how the prediction is obtained.
        execute: Whether extracted programs should actually be run.

    Returns:
        ``(prediction, report)`` where the prediction has been passed through
        ``strip_string`` and ``report`` is ``None`` unless code was executed.
    """
    if not result or result == "error":
        return None, None

    report = None
    if "program_only" in prompt_type:
        raw_prediction = extract_program_output(result)
    elif prompt_type in ["pot", "pal"] and execute:
        raw_prediction, report = executor.apply(extract_program(result))
    else:
        raw_prediction = extract_answer(result)

    return strip_string(raw_prediction), report
| |
|
| |
|
def parse_digits(num):
    """
    Convert a number-like value to ``float``; return ``None`` if impossible.

    Accepted formats are plain numbers such as ``234.23`` or ``"1,234"``
    (commas are removed before parsing) and percentages such as ``"23%"``
    or the LaTeX-escaped ``"23\\%"``, which are divided by 100.

    Args:
        num: Any value; it is converted with ``str()`` before parsing.

    Returns:
        The parsed ``float``, or ``None`` when the text is not a number.
    """
    text = str(num).replace(",", "")
    try:
        return float(text)
    except ValueError:
        # Not a plain number; it may still be a percentage.
        pass

    if not text.endswith("%"):
        return None

    text = text[:-1]
    if text.endswith("\\"):
        # LaTeX writes the percent sign as "\%"; drop the escape too.
        text = text[:-1]
    try:
        return float(text) / 100
    except ValueError:
        return None
| |
|
| |
|
def is_digit(num):
    """Return True when ``parse_digits`` can turn ``num`` into a number."""
    parsed = parse_digits(num)
    return parsed is not None
| |
|
| |
|
def normalize_prediction(prediction):
    """
    Normalise a predicted answer into a canonical string.

    1. If ``is_digit`` accepts the prediction and it parses as a float
       (after removing commas), the float rounded to 6 decimals is returned
       as a string.
    2. Otherwise the text is stripped, enclosing ``[...]`` / ``(...)`` pairs
       are peeled off, comma-separated elements are normalised recursively,
       the peeled brackets are put back, and the text is run through
       ``parse_latex`` / ``parse_expr`` when either accepts it. Finally all
       "{", "}", "(" and ")" characters are removed.

    Args:
        prediction: The predicted answer (any type; converted with ``str``).

    Returns:
        The normalised prediction as a ``str``.
    """
    try:
        if is_digit(prediction):
            return str(np.round(float(str(prediction).replace(",", "")), 6))
    except ValueError:
        # e.g. "50%" is accepted by is_digit but rejected by float();
        # fall through and treat it as a symbolic answer.
        pass

    text = str(prediction).strip()

    # Peel enclosing brackets, remembering which kind was removed so they can
    # be restored after the elements have been normalised.
    openers = []
    while (text.startswith("[") and text.endswith("]")) or (
        text.startswith("(") and text.endswith(")")
    ):
        openers.append(text[0])
        text = text[1:-1]

    if openers and "," in text:
        text = ",".join(normalize_prediction(part) for part in text.split(","))

    for opener in reversed(openers):
        text = f"[{text}]" if opener == "[" else f"({text})"

    # NOTE(review): sympy's parse_expr evaluates its input with eval(); the
    # prediction is presumably untrusted model output, so this should only
    # run in a sandboxed process.
    parsed = text
    for parser in (parse_latex, parse_expr):
        try:
            parsed = parser(text)
        except Exception:
            continue
        break

    # The parsers return SymPy objects (or tuples), which do not support the
    # string replacement below; go back to text first.
    text = parsed if isinstance(parsed, str) else str(parsed)

    for token in ("{", "}", "(", ")"):
        text = text.replace(token, "")

    return text
| |
|
| |
|
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    The checks run in this order: identical string form; numeric comparison;
    element-wise comparison of bracketed tuples/lists; element-wise
    comparison of pmatrix/bmatrix bodies; equation handling; and finally a
    symbolic comparison of the whole strings.

    Args:
        prediction: The predicted answer.
        reference: The ground-truth answer.
        include_percentage: When true, a numeric reference ``r`` also matches
            ``r / 100`` and ``r * 100``.
        is_close: When true, numbers are compared with
            ``isclose(..., abs_tol=1e-3)``; otherwise with ``==``.
        timeout: When true, every symbolic comparison, including those made
            for tuple/matrix elements and equations, goes through
            ``call_with_timeout`` instead of calling ``symbolic_equal``
            directly.

    Returns:
        True when the two answers are judged equal, False otherwise.
    """
    if str(prediction) == str(reference):
        return True

    def _recurse(pred_part, ref_part):
        # Sub-comparisons must inherit every option, in particular `timeout`,
        # otherwise nested symbolic checks would run unbounded.
        return math_equal(pred_part, ref_part, include_percentage, is_close, timeout)

    def _symbolic(lhs, rhs):
        if timeout:
            return bool(call_with_timeout(symbolic_equal_process, lhs, rhs))
        return bool(symbolic_equal(lhs, rhs))

    # 1. Numerical equality.
    try:
        if is_digit(prediction) and is_digit(reference):
            pred_value = parse_digits(prediction)
            ref_value = parse_digits(reference)
            if include_percentage:
                candidates = [ref_value / 100, ref_value, ref_value * 100]
            else:
                candidates = [ref_value]
            for candidate in candidates:
                try:
                    if is_close:
                        if isclose(candidate, pred_value, abs_tol=1e-3):
                            return True
                    elif candidate == pred_value:
                        return True
                except Exception:
                    continue
            return False
    except Exception:
        # Best effort: fall back to the symbolic checks below.
        pass

    # Empty predictions never match (0 and False are legitimate answers).
    if not prediction and prediction not in [0, False]:
        return False

    # 2. Symbolic equality.
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # Tuples / lists / intervals such as "(1, 2)" or "[a, b]": compare
    # element by element.
    bracketed = r"(\(|\[).+(\)|\])"
    if (
        regex.match(bracketed, prediction) is not None
        and regex.match(bracketed, reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and all(
            _recurse(pred_part, ref_part)
            for pred_part, ref_part in zip(pred_parts, ref_parts)
        ):
            return True

    # LaTeX matrices: compare row by row ("\\"), cell by cell ("&").
    matrix_begin = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_end = ("\\end{pmatrix}", "\\end{bmatrix}")
    if (
        prediction.startswith(matrix_begin)
        and prediction.endswith(matrix_end)
        and reference.startswith(matrix_begin)
        and reference.endswith(matrix_end)
    ):
        # Both delimiter variants have the same length, so one slice fits all.
        body = slice(len("\\begin{pmatrix}"), -len("\\end{pmatrix}"))
        pred_rows = [
            row.strip() for row in prediction[body].split("\\\\") if row.strip()
        ]
        ref_rows = [
            row.strip() for row in reference[body].split("\\\\") if row.strip()
        ]

        def _row_equal(pred_row, ref_row):
            pred_cells = pred_row.split("&")
            ref_cells = ref_row.split("&")
            return len(pred_cells) == len(ref_cells) and all(
                _recurse(pred_cell, ref_cell)
                for pred_cell, ref_cell in zip(pred_cells, ref_cells)
            )

        if len(pred_rows) == len(ref_rows) and all(
            _row_equal(pred_row, ref_row)
            for pred_row, ref_row in zip(pred_rows, ref_rows)
        ):
            return True

    if prediction.count("=") == 1 and reference.count("=") == 1:
        # Two equations: compare "lhs - (rhs)" of each, up to an overall sign.
        pred_lhs, pred_rhs = prediction.split("=")
        pred_diff = f"{pred_lhs.strip()} - ({pred_rhs.strip()})"
        ref_lhs, ref_rhs = reference.split("=")
        ref_diff = f"{ref_lhs.strip()} - ({ref_rhs.strip()})"
        if _symbolic(pred_diff, ref_diff) or _symbolic(f"-({pred_diff})", ref_diff):
            return True
    elif (
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        # Prediction looks like "x = value": compare only the value.
        if _recurse(prediction.split("=")[1], reference):
            return True
    elif (
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        # Reference looks like "x = value": compare only the value.
        if _recurse(prediction, reference.split("=")[1]):
            return True

    # Last resort: symbolic comparison of the whole strings.
    return _symbolic(prediction, reference)
| |
|
| |
|
def math_equal_process(param):
    """
    Compare the last two items of ``param`` with ``math_equal``.

    ``param[-2]`` is used as the prediction and ``param[-1]`` as the
    reference; all other ``math_equal`` options keep their defaults.
    """
    prediction, reference = param[-2], param[-1]
    return math_equal(prediction, reference)
| |
|
| |
|
| | def symbolic_equal(a, b): |
| | def _parse(s): |
| | for f in [parse_latex, parse_expr]: |
| | try: |
| | return f(s) |
| | except: |
| | pass |
| | return s |
| |
|
| | a = _parse(a) |
| | b = _parse(b) |
| |
|
| | try: |
| | if simplify(a - b) == 0: |
| | return True |
| | except: |
| | pass |
| |
|
| | try: |
| | if isclose(N(a), N(b), abs_tol=1e-3): |
| | return True |
| | except: |
| | pass |
| | return False |
| |
|
| |
|
def symbolic_equal_process(a, b, output_queue):
    """Run ``symbolic_equal(a, b)`` and put its result on ``output_queue``."""
    output_queue.put(symbolic_equal(a, b))
| |
|
| |
|
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """
    Run ``func`` in a child process and return the value it reports.

    ``func`` is called as ``func(*args, output_queue, **kwargs)`` and is
    expected to ``put`` exactly one result on ``output_queue``.

    Args:
        func: Callable to run in the child process.
        *args: Positional arguments placed before the queue.
        timeout: Seconds to wait for the child to finish. ``None`` waits
            without limit.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value the child put on the queue, or ``False`` when the child
        did not finish in time (it is terminated) or exited without
        reporting a result (for example because it raised).
    """
    from queue import Empty  # raised by multiprocessing.Queue.get on timeout

    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=func, args=args + (output_queue,), kwargs=kwargs
    )
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    try:
        # The child has exited. Bound the wait so a child that died before
        # reporting a result cannot block the caller forever.
        return output_queue.get(timeout=timeout)
    except Empty:
        return False
| |
|