| | """ |
| | This file comes from: https://github.com/microsoft/ToRA/blob/main/src/utils/parser.py |
| | """ |
| | import re |
| | from typing import Any, Dict |
| |
|
| |
|
| | def _fix_fracs(string): |
| | substrs = string.split("\\frac") |
| | new_str = substrs[0] |
| | if len(substrs) > 1: |
| | substrs = substrs[1:] |
| | for substr in substrs: |
| | new_str += "\\frac" |
| | if len(substr) > 0 and substr[0] == "{": |
| | new_str += substr |
| | else: |
| | try: |
| | assert len(substr) >= 2 |
| | except: |
| | return string |
| | a = substr[0] |
| | b = substr[1] |
| | if b != "{": |
| | if len(substr) > 2: |
| | post_substr = substr[2:] |
| | new_str += "{" + a + "}{" + b + "}" + post_substr |
| | else: |
| | new_str += "{" + a + "}{" + b + "}" |
| | else: |
| | if len(substr) > 2: |
| | post_substr = substr[2:] |
| | new_str += "{" + a + "}" + b + post_substr |
| | else: |
| | new_str += "{" + a + "}" + b |
| | string = new_str |
| | return string |
| |
|
| |
|
| | def _fix_a_slash_b(string): |
| | if len(string.split("/")) != 2: |
| | return string |
| | a = string.split("/")[0] |
| | b = string.split("/")[1] |
| | try: |
| | if "sqrt" not in a: |
| | a = int(a) |
| | if "sqrt" not in b: |
| | b = int(b) |
| | assert string == "{}/{}".format(a, b) |
| | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" |
| | return new_string |
| | except: |
| | return string |
| |
|
| |
|
| | def _fix_sqrt(string): |
| | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) |
| | return _string |
| |
|
| |
|
def strip_string(string):
    r"""Normalize an answer string so equivalent answers compare equal.

    The input is converted with ``str()`` and then passed through a fixed
    sequence of textual clean-ups: whitespace and trailing periods, LaTeX
    spacing and sizing commands, units, degree/dollar/percent signs,
    infinity spellings, quotes, trailing ``.0`` on numbers, a short
    ``x =`` prefix, and finally brace fixes for ``\sqrt`` and ``\frac``.
    The order of the steps matters; later steps rely on earlier ones.

    Args:
        string: The raw answer; any object accepted by ``str()``.

    Returns:
        The normalized answer string. An answer that becomes empty during
        clean-up is returned as ``""`` without the final sqrt/frac fixes.
    """
    string = str(string).strip()
    # Drop line breaks.
    string = string.replace("\n", "")

    # Drop trailing periods.
    string = string.rstrip(".")

    # Remove the LaTeX spacing commands "\!" and "\ ".
    string = string.replace("\\!", "")
    string = string.replace("\\ ", "")

    # Collapse doubled backslashes; two passes so longer runs shrink further.
    string = string.replace("\\\\", "\\")
    string = string.replace("\\\\", "\\")

    # Treat \tfrac and \dfrac as \frac.
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # Remove \left and \right sizing commands.
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove a trailing \text{...} (e.g. a unit), but only when something
    # other than that text remains.
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    # Remove degree markers.
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # Remove dollar signs, escaped or not.
    string = string.replace("\\$", "")
    string = string.replace("$", "")

    string = string.replace("\\text", "")
    string = string.replace("x\\in", "")

    # Remove percent signs. The escaped form is removed twice because the
    # first pass can bring a backslash and a "%" together.
    string = string.replace("\\%", "")
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # Give bare decimals a leading zero: " ." -> " 0." and "{." -> "{0.".
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # Remove \cdot.
    string = string.replace("\\cdot", "")

    # Spell infinity uniformly as \infty and drop a redundant leading "+".
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    string = string.replace("+\\infty", "\\infty")

    # Remove the word "and" and the \mathbf command.
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # Remove \mbox{...} together with its content.
    string = re.sub(r"\\mbox{.*?}", "", string)

    # Remove quote characters (str.replace returns a new string, so the
    # result has to be assigned back).
    string = string.replace("'", "")
    string = string.replace("\"", "")

    # Use "i" for the imaginary unit when only "j" appears.
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # Drop an all-zero fractional part: "a.000b" -> "ab" where b is a
    # non-digit or the end of the string.
    string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0+$", r"\1", string)

    # Nothing left to normalize.
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # Drop a short left-hand side such as "k=" (at most two characters
    # before a single "=").
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac12 -> \frac{1}{2}, \frac1{72} -> \frac{1}{72}.
    string = _fix_fracs(string)

    # Plain "a/b" -> \frac{a}{b} in the simple integer/sqrt case.
    string = _fix_a_slash_b(string)

    return string
| |
|
def extract_answer(pred_str):
    r"""Pull the final answer out of a model response.

    Sources are tried in this order:

    1. the text after the last ``boxed`` (brace-balanced content when it is
       followed by ``{``, otherwise everything up to the next ``$``);
    2. the text after the last ``he answer is``;
    3. the last program output block, when it is non-empty;
    4. the last number in the text (commas removed first).

    Only the first line of the candidate is kept, a leading ``:`` and a
    trailing ``.`` or ``/`` are dropped, and the result is normalized with
    ``strip_string``.

    Args:
        pred_str: The full response text.

    Returns:
        The normalized answer, or ``""`` when ``boxed`` is the very end of
        the text.
    """
    if 'boxed' in pred_str:
        ans = pred_str.split('boxed')[-1]
        if len(ans) == 0:
            return ""
        if ans[0] == '{':
            # Collect characters up to the brace that closes the opening
            # one; nested braces are kept as part of the answer.
            depth = 1
            collected = []
            for c in ans[1:]:
                if c == '{':
                    depth += 1
                elif c == '}':
                    depth -= 1
                    if depth == 0:
                        break
                collected.append(c)
            pred = ''.join(collected)
        else:
            pred = ans.split('$')[0].strip()
    elif 'he answer is' in pred_str:
        pred = pred_str.split('he answer is')[-1].strip()
    else:
        # Fall back to the program output, extracted once and reused.
        program_output = extract_program_output(pred_str)
        if program_output != "":
            pred = program_output
        else:
            # Last resort: the last number in the text. Raw string so the
            # regex escapes are not parsed as string escapes.
            numbers = re.findall(r'-?\d*\.?\d+', pred_str.replace(",", ""))
            pred = numbers[-1] if numbers else ''

    # Keep only the first line of a multi-line candidate.
    pred = pred.split("\n")[0]
    if pred != "" and pred[0] == ":":
        pred = pred[1:]
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]
    if pred != "" and pred[-1] == "/":
        pred = pred[:-1]
    pred = strip_string(pred)
    return pred
| |
|
| |
|
def extract_program(result: str, last_only=True):
    """
    Collect the code found inside "```python" fences of `result`.

    Lines between a line starting with "```python" and the next line
    starting with "```" are gathered, each terminated by a newline. With
    `last_only` set, every new fence discards what was gathered before, so
    only the final program is returned; otherwise all programs are kept and
    each one is preceded by a "# ========" separator. A fence that is never
    closed still contributes its lines.
    """
    pieces = []
    inside_fence = False
    for row in result.split("\n"):
        if row.startswith("```python"):
            if last_only:
                # A new program replaces whatever was collected so far.
                pieces = []
            else:
                pieces.append("\n# ========\n")
            inside_fence = True
            continue
        if row.startswith("```"):
            inside_fence = False
            continue
        if inside_fence:
            pieces.append(row + "\n")
    return "".join(pieces)
| |
|
| |
|
def extract_program_output(pred_str):
    """
    Return the content of the last "```output" block in `pred_str`.

    The text after the last "```output" marker is cut at the next "```"
    (when there is one) and stripped of surrounding whitespace. An empty
    string is returned when the marker does not occur at all.
    """
    marker = '```output'
    if marker not in pred_str:
        return ""
    after_marker = pred_str.split(marker)[-1]
    block = after_marker.split('```')[0]
    return block.strip()
| |
|
| |
|
def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return the ``(rationale, answer)`` ground truth of a dataset example.

    An example that already has a ``gt_cot`` key is used directly, with its
    ``gt`` value normalized by ``strip_string``. Otherwise the fields to
    read are chosen by ``data_name``; the rationale is converted with
    ``str()`` and stripped, and the answer is normalized by
    ``strip_string``.

    Args:
        example: One dataset record.
        data_name: Dataset identifier, one of ``math``, ``ocw``, ``gsm8k``,
            ``gsm-hard``, ``svamp``, ``asdiv``, ``mawps``, ``tabmwp``,
            ``bbh``.

    Returns:
        A ``(rationale, answer)`` tuple.

    Raises:
        NotImplementedError: ``data_name`` is not one of the names above
            (and the example has no ``gt_cot`` key).
    """
    if 'gt_cot' in example:
        return example['gt_cot'], strip_string(example['gt'])

    if data_name in ("math", 'ocw'):
        # The answer is extracted from the worked solution itself.
        rationale = example['solution']
        answer = extract_answer(rationale)
    elif data_name == "gsm8k":
        rationale, answer = example['answer'].split("####")
    elif data_name == "gsm-hard":
        rationale = example['code']
        answer = example['target']
    elif data_name == "svamp":
        rationale = example['Equation']
        answer = example['Answer']
    elif data_name == "asdiv":
        rationale = example['formula']
        # Drop parenthesised parts of the answer.
        answer = re.sub(r"\(.*?\)", "", example['answer'])
    elif data_name in ("mawps", "bbh"):
        rationale = None
        answer = example['target']
    elif data_name == "tabmwp":
        rationale = example['solution']
        answer = example['answer']
        if example['ans_type'] in ('integer_number', 'decimal_number'):
            # Numeric answers may be written as a fraction, with thousands
            # separators, or as a percentage.
            if '/' in answer:
                pieces = answer.split('/')
                answer = int(pieces[0]) / int(pieces[1])
            elif ',' in answer:
                answer = float(answer.replace(',', ''))
            elif '%' in answer:
                answer = float(answer.split('%')[0]) / 100
            else:
                answer = float(answer)
    else:
        raise NotImplementedError(data_name)

    return str(rationale).strip(), strip_string(answer)
| |
|
| |
|
def parse_question(example, data_name):
    """Build the question text for a dataset example.

    ``asdiv`` joins the stripped body and question; ``svamp`` does the same
    after making sure the body ends with a period; ``tabmwp`` renders the
    table (with its title when present), the question and any answer
    choices. Every other dataset uses the first of the keys ``question``,
    ``problem``, ``Question``, ``input`` that exists in the example.

    The result must not be the empty string (checked with an assert) and is
    returned stripped.
    """
    if data_name == "asdiv":
        question = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        body = body if body.endswith(".") else body + "."
        question = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title_str = ""
        if example['table_title']:
            title_str = f'regarding "{example["table_title"]}" '
        question = (
            f'Read the following table {title_str}and answer a question:\n'
            f'{example["table"]}\n{example["question"]}'
        )
        if example['choices']:
            question = question + f' Please select from the following options: {example["choices"]}'
    else:
        # First matching key wins; fall back to "" when none is present.
        present = (
            example[key]
            for key in ('question', 'problem', 'Question', 'input')
            if key in example
        )
        question = next(present, "")
    assert question != ""
    return question.strip()
| |
|
| |
|
def run_execute(executor, result, prompt_type, execute=False):
    """Turn a raw generation into a ``(prediction, report)`` pair.

    An empty result or the literal ``'error'`` yields ``(None, None)``.
    Otherwise the prediction comes from one of three places:

    * the last program output block, when ``prompt_type`` contains
      ``"program_only"``;
    * ``executor.apply`` run on the extracted program, when ``prompt_type``
      is ``"pot"`` or ``"pal"`` and ``execute`` is true (this is the only
      case that sets ``report``);
    * the answer extracted from the text, in every other case.

    The prediction is normalized with ``strip_string`` before returning.
    """
    if not result:
        return None, None
    if result == 'error':
        return None, None

    report = None
    if "program_only" in prompt_type:
        raw_prediction = extract_program_output(result)
    elif prompt_type in ("pot", "pal") and execute:
        program = extract_program(result)
        raw_prediction, report = executor.apply(program)
    else:
        raw_prediction = extract_answer(result)

    return strip_string(raw_prediction), report
| |
|