| | import re |
| | import regex |
| |
|
| |
|
| | def _fix_fracs(string): |
| | substrs = string.split("\\frac") |
| | new_str = substrs[0] |
| | if len(substrs) > 1: |
| | substrs = substrs[1:] |
| | for substr in substrs: |
| | new_str += "\\frac" |
| | if len(substr) > 0 and substr[0] == "{": |
| | new_str += substr |
| | else: |
| | try: |
| | assert len(substr) >= 2 |
| | except: |
| | return string |
| | a = substr[0] |
| | b = substr[1] |
| | if b != "{": |
| | if len(substr) > 2: |
| | post_substr = substr[2:] |
| | new_str += "{" + a + "}{" + b + "}" + post_substr |
| | else: |
| | new_str += "{" + a + "}{" + b + "}" |
| | else: |
| | if len(substr) > 2: |
| | post_substr = substr[2:] |
| | new_str += "{" + a + "}" + b + post_substr |
| | else: |
| | new_str += "{" + a + "}" + b |
| | string = new_str |
| | return string |
| |
|
| |
|
| | def _fix_a_slash_b(string): |
| | if len(string.split("/")) != 2: |
| | return string |
| | a = string.split("/")[0] |
| | b = string.split("/")[1] |
| | try: |
| | if "sqrt" not in a: |
| | a = int(a) |
| | if "sqrt" not in b: |
| | b = int(b) |
| | assert string == "{}/{}".format(a, b) |
| | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" |
| | return new_string |
| | except: |
| | return string |
| |
|
| |
|
| | def _fix_sqrt(string): |
| | _string = re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string) |
| | _string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string) |
| | return _string |
| |
|
| |
|
| | def _fix_tan(string): |
| | _string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string) |
| | _string = re.sub(r"\\tan\s+(\w+)$", r"\\tan{\1}", _string) |
| | return _string |
| |
|
| |
|
def strip_string(string):
    r"""Normalize a LaTeX math answer into a canonical string for comparison.

    Steps, in order:

    1. Whitespace, newline, trailing-period and ``\!`` cleanup.
    2. Unwrap ``\text{...}`` answers and drop a trailing ``\text{...}``.
    3. Normalize ``\tfrac``/``\dfrac``/``\cfrac`` and drop ``\left``/``\right``.
    4. Strip degree signs, ``{m}``/``{cm}``/``{mm}`` units and trailing ``p.m.``.
    5. Remove dollar signs and ``x\in``, and unescape ``\%``.
    6. Add leading zeros to bare decimals; remove ``\cdot``.
    7. Spell infinity as ``\infty``; remove ``\mathbf``/``\mathrm``/``\mbox{}``.
    8. Remove quotes; use ``i`` for ``j`` when ``i`` does not appear.
    9. Drop trailing ``.0`` on numbers.
    10. Brace ``\sqrt``, ``\tan`` and ``\frac`` arguments, remove spaces, and
        rewrite ``a/b`` as ``\frac{a}{b}``.

    Args:
        string: The raw answer; any object is accepted and passed through
            ``str``.

    Returns:
        The normalized answer string (possibly empty).
    """
    string = str(string).strip()
    # Linebreaks.
    string = string.replace("\n", "")
    # Trailing sentence period(s).
    string = string.rstrip(".")
    # LaTeX negative thin space.
    string = string.replace("\\!", "")

    # Unwrap an answer given entirely as \text{...}.
    if string.startswith("\\text{") and string.endswith("}"):
        string = string.split("{", 1)[1][:-1]

    # Normalize fraction variants to \frac.
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = string.replace("cfrac", "frac")

    # Drop \left / \right delimiter sizing.
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Drop a trailing \text{...} (e.g. units), unless that would leave nothing.
    without_trailing_text = re.sub(r"\\text{.*?}$", "", string).strip()
    if without_trailing_text != "" and without_trailing_text != string:
        string = without_trailing_text

    # Degree signs.
    string = string.replace("^{\\circ}", "").strip()
    string = string.replace("^\\circ", "").strip()

    # "{m}", "{cm}", "{mm}" optionally followed by "^2"/"^3", anywhere.
    string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip()
    # A trailing "p.m.".
    string = regex.sub(r"p\.m\.$", "", string).strip()
    # A trailing "t" right after a digit (optionally separated by spaces).
    string = regex.sub(r"(\d)\s*t$", r"\1", string).strip()

    # Dollar signs (escaped and bare).
    string = string.replace("\\$", "")
    string = string.replace("$", "")

    # Literal "x\in" prefix (e.g. "x\in[0,1]").
    string = string.replace("x\\in", "")

    # Unescape percent signs. (A second copy of this line spelled "\\%" as
    # "\%", an invalid escape that warns on Python 3.12+; it was identical.)
    string = string.replace("\\%", "%")

    # Bare decimals: " .5" -> " 0.5", "{.5" -> "{0.5".
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # Multiplication dots.
    string = string.replace("\\cdot", "")

    # Spell infinity consistently as \infty.
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    # Treat "+\infty" as "\infty".
    string = string.replace("+\\infty", "\\infty")

    # Font commands.
    string = string.replace("\\mathbf", "")
    string = string.replace("\\mathrm", "")

    # \mbox{...} annotations together with their contents.
    string = re.sub(r"\\mbox{.*?}", "", string)

    # Quotes. str.replace returns a new string, so the result must be
    # reassigned or the removal silently does nothing.
    string = string.replace("'", "")
    string = string.replace('"', "")

    # Use "i" for the imaginary unit when "j" is used and "i" is absent.
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # Trailing zeros after a decimal point: "2.0" -> "2", "2.00)" -> "2)".
    string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0+$", r"\1", string)

    if not string:
        return string
    # Leading bare decimal: ".5" -> "0.5".
    if string[0] == ".":
        string = "0" + string

    # Brace shorthand \sqrt / \tan arguments, then remove all spaces.
    string = _fix_sqrt(string)
    string = _fix_tan(string)
    string = string.replace(" ", "")

    # Brace shorthand \frac arguments, e.g. \frac12 -> \frac{1}{2}.
    string = _fix_fracs(string)

    # Simple integer ratio: "a/b" -> "\frac{a}{b}".
    string = _fix_a_slash_b(string)

    # Trailing backslashes, commas and periods.
    string = regex.sub(r"(\\|,|\.)+$", "", string)

    return string
| |
|
| |
|
def extract_boxed_answers(text):
    r"""Return the contents of every ``boxed{...}`` group in *text*, in order.

    Braces are matched with a depth counter, so nested groups such as
    ``\boxed{\frac{1}{2}}`` are returned whole. A group whose closing brace
    never appears contributes nothing.

    NOTE(review): when the closing brace is immediately followed by ``%``,
    the returned text keeps the closing ``}`` but not the ``%`` (e.g.
    ``\boxed{50}%`` -> ``"50}"``). This looks unintended; confirm before
    changing it, since downstream scores may depend on it.
    """
    results = []
    for chunk in text.split("boxed{")[1:]:
        depth = 0
        for idx, ch in enumerate(chunk):
            if ch == "{":
                depth += 1
                continue
            if ch != "}":
                continue
            depth -= 1
            if depth >= 0:
                continue
            # Found the brace that closes the boxed group.
            percent_follows = chunk[idx + 1 : idx + 2] == "%"
            end = idx + 1 if percent_follows else idx
            results.append(chunk[:end])
            break
    return results
| |
|
| |
|
def extract_program_output(pred_str):
    """Return the body of the last ```output block in *pred_str*.

    The text after the last "```output" marker is taken up to the next "```"
    fence, or to the end of the string if there is none, and stripped of
    surrounding whitespace. Returns "" when there is no "```output" marker.
    """
    marker = "```output"
    if marker not in pred_str:
        return ""
    after_marker = pred_str.rpartition(marker)[2]
    body = after_marker.partition("```")[0]
    return body.strip()
| |
|
| |
|
def extract_answer(pred_str, exhaust=False):
    r"""Extract and normalize the final answer(s) from a model completion.

    Candidate sources, tried in priority order:

    1. the text between ``final answer is $`` and ``$. I hope``;
    2. every ``boxed{...}`` group (via ``extract_boxed_answers``);
    3. the text after the last ``he answer is`` (matches "The"/"the");
    4. the last ```output block (via ``extract_program_output``);
    5. the last number in the text, after removing commas.

    Each candidate is stripped and cut at its first newline. A leading ``:``
    and trailing ``.`` / ``/`` are removed, and the result is normalized
    with ``strip_string``.

    Args:
        pred_str: The model completion.
        exhaust: If True, return every extracted candidate.

    Returns:
        A list of normalized candidates when ``exhaust`` is True. Otherwise
        the last candidate, or ``""`` if none was found.
    """
    pred = []
    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        tmp = pred_str.split("final answer is $", 1)[1]
        pred = [tmp.split("$. I hope", 1)[0].strip()]
    elif "boxed" in pred_str:
        pred = extract_boxed_answers(pred_str)
    elif "he answer is" in pred_str:
        pred = [pred_str.split("he answer is")[-1].strip()]
    else:
        program_output = extract_program_output(pred_str)
        if program_output != "":
            pred.append(program_output)
        else:
            # Last resort: the last number in the text. Raw string so that
            # "\d" and "\." are regex escapes, not invalid string escapes
            # (which emit SyntaxWarning on Python 3.12+).
            numbers = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
            if numbers:
                pred.append(numbers[-1])

    normalized = []
    for ans in pred:
        ans = ans.strip().split("\n")[0]
        ans = ans.lstrip(":").rstrip(".").rstrip("/")
        normalized.append(strip_string(ans))
    if exhaust:
        return normalized
    return normalized[-1] if normalized else ""
| |
|
| |
|
def extract_math_answer(question, reasoning, task):
    r"""Extract MATH-style answers, splitting multi-part answers into items.

    Every answer from ``extract_answer(reasoning, exhaust=True)`` is handled
    as follows:

    * split on commas when *question* mentions "separated by commas" and the
      answer has no parentheses or brackets (so intervals/tuples stay whole);
    * otherwise split on ``\text{and}`` separators if present;
    * otherwise kept as a single item.

    Every item is whitespace-stripped. ``task`` is unused.
    """
    and_separator = r"\\text\{\s*and\s*\}"
    collected = []
    for candidate in extract_answer(reasoning, exhaust=True):
        if "separated by commas" in question and not any(
            ch in candidate for ch in "()[]"
        ):
            pieces = candidate.split(",")
        elif regex.search(and_separator, candidate):
            pieces = regex.sub(and_separator, "[SEP]", candidate).split("[SEP]")
        else:
            pieces = [candidate]
        collected.extend(piece.strip() for piece in pieces)
    return collected
| |
|
| |
|
def extract_math_few_shot_cot_answer(question, reasoning, task):
    """Truncate at the next "Problem:" marker, then extract MATH answers.

    Only the text before the first "Problem:" is kept; the result is passed
    to ``extract_math_answer`` together with *question* and *task*.
    """
    own_part = reasoning.partition("Problem:")[0]
    return extract_math_answer(question, own_part, task)
| |
|
| |
|
def extract_last_single_answer(question, reasoning, task):
    """Return only the last normalized answer in *reasoning* ("" if none).

    ``question`` and ``task`` are accepted for a uniform extractor signature
    and are ignored.
    """
    return extract_answer(pred_str=reasoning, exhaust=False)
| |
|
| |
|
def extract_gsm_few_shot_cot_answer(question, reasoning, task):
    """Return the last number in the completion, or "[invalid]" if none.

    The completion is first truncated at the next "Q: " marker.
    ``question`` and ``task`` are ignored.
    """
    own_part = reasoning.partition("Q: ")[0]
    numbers = regex.findall(r"-?\d+\.?\d*", own_part)
    return numbers[-1] if numbers else "[invalid]"
| |
|
| |
|
def extract_agieval_gaokao_mathcloze_few_shot_cot_test(question, reasoning, task):
    """Extract a fill-in-the-blank answer from a Chinese few-shot completion.

    The completion is truncated at the next Chinese "question" marker
    (followed by a space). The answer is the first line after the Chinese
    "the answer is" marker, whitespace-stripped and with surrounding "$"
    removed.

    Returns:
        A one-element list with the answer, or ``["placeholder"]`` if the
        answer marker does not occur.
    """
    own_part = reasoning.partition("问题 ")[0]
    if "答案是" not in own_part:
        return ["placeholder"]
    after_marker = own_part.partition("答案是")[2].strip()
    first_line = after_marker.split("\n")[0].strip()
    return [first_line.strip("$")]
| |
|
| |
|
def extract_agieval_gaokao_mathqa_few_shot_cot_test(question, reasoning, task):
    """Extract a multiple-choice answer from a Chinese few-shot completion.

    The completion is truncated at the next Chinese "question" marker
    (followed by a space). The first whitespace-stripped line after the
    Chinese "the answer is" marker is returned, or "placeholder" when the
    marker is absent.
    """
    own_part = reasoning.partition("问题 ")[0]
    _, found, tail = own_part.partition("答案是")
    if not found:
        return "placeholder"
    return tail.strip().split("\n")[0].strip()
| |
|
| |
|
def extract_sat_few_shot_answer(question, reasoning, task):
    """Return the choice letter (A-D) after "the final answer is".

    The completion is truncated at the next "Problem:" marker (case
    sensitive), then lowercased before matching, so the phrase is matched
    case-insensitively. The letter may be wrapped in parentheses. Returns
    the upper-case letter, or "placeholder" if no match is found.
    """
    own_part = reasoning.partition("Problem:")[0]
    found = regex.search(
        r"the final answer is \(?(?P<ans>[abcd])\)?", own_part.lower()
    )
    if found is None:
        return "placeholder"
    return found.group("ans").upper()
| |
|
| |
|
def extract_ocwcourses_few_shot_answer(question, reasoning, task):
    """Extract the text between "final answer is " and ". I hope it is correct".

    The completion is truncated at the next "Problem:" marker. If the
    sentence is not found, the truncated completion is printed to stdout for
    debugging and "[invalid]" is returned.

    NOTE(review): the trailing "." in the pattern is unescaped, so any
    character is accepted after "correct".
    """
    own_part = reasoning.partition("Problem:")[0]
    found = regex.search(
        r"final answer is (?P<ans>.*)\. I hope it is correct.", own_part
    )
    if found is not None:
        return found.group("ans")
    print(f"DEBUG >>>\n{own_part}", flush=True)
    return "[invalid]"
| |
|
| |
|
def extract_mmlu_stem(question, reasoning, task):
    """Extract an MMLU-STEM choice letter using the SAT letter extractor.

    The completion is truncated at the first "Problem:" marker before being
    passed to ``extract_sat_few_shot_answer``.
    """
    truncated = reasoning.split("Problem:", 1)[0]
    return extract_sat_few_shot_answer(question, truncated, task)
| |
|
| |
|
def extract_minif2f_isabelle(question, reasoning, task):
    """Return the completion up to the first "Informal:" marker, stripped.

    ``question`` and ``task`` are ignored.
    """
    formal_part, _, _ = reasoning.partition("Informal:")
    return formal_part.strip()
| |
|
| |
|
def extract_cmath_few_shot_test(question, reasoning, task):
    """Extract a numeric answer from a Chinese few-shot completion.

    The completion is truncated at the next Chinese "question:" marker
    (ASCII colon). If it contains the Chinese "the answer is" marker:

    * the first line after the marker is taken;
    * ":" and then the Chinese full stop are stripped from its ends;
    * the last number in that line is returned.

    If that line contains no number, the truncated completion is printed for
    debugging and "[invalid]" is returned. Without the marker, the result of
    ``extract_last_single_answer`` is returned.
    """
    if "问题:" in reasoning:
        reasoning = reasoning.split("问题:", 1)[0]
    if "答案是" not in reasoning:
        return extract_last_single_answer(question, reasoning, task)
    ans = reasoning.split("答案是", 1)[1].strip()
    ans = ans.split("\n")[0]
    ans = ans.strip(":")
    ans = ans.strip("。")
    numbers = regex.findall(r"-?\d+\.?\d*", ans)
    if not numbers:
        # Explicit check replaces a bare ``except:`` that could also swallow
        # KeyboardInterrupt/SystemExit and hide unrelated errors.
        print(f"DEBUG CMATH: {reasoning}", flush=True)
        return "[invalid]"
    return numbers[-1]
| |
|