|
import re |
|
import regex |
|
|
|
|
|
def _fix_fracs(string): |
|
substrs = string.split("\\frac") |
|
new_str = substrs[0] |
|
if len(substrs) > 1: |
|
substrs = substrs[1:] |
|
for substr in substrs: |
|
new_str += "\\frac" |
|
if len(substr) > 0 and substr[0] == "{": |
|
new_str += substr |
|
else: |
|
try: |
|
assert len(substr) >= 2 |
|
except: |
|
return string |
|
a = substr[0] |
|
b = substr[1] |
|
if b != "{": |
|
if len(substr) > 2: |
|
post_substr = substr[2:] |
|
new_str += "{" + a + "}{" + b + "}" + post_substr |
|
else: |
|
new_str += "{" + a + "}{" + b + "}" |
|
else: |
|
if len(substr) > 2: |
|
post_substr = substr[2:] |
|
new_str += "{" + a + "}" + b + post_substr |
|
else: |
|
new_str += "{" + a + "}" + b |
|
string = new_str |
|
return string |
|
|
|
|
|
def _fix_a_slash_b(string): |
|
if len(string.split("/")) != 2: |
|
return string |
|
a = string.split("/")[0] |
|
b = string.split("/")[1] |
|
try: |
|
if "sqrt" not in a: |
|
a = int(a) |
|
if "sqrt" not in b: |
|
b = int(b) |
|
assert string == "{}/{}".format(a, b) |
|
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" |
|
return new_string |
|
except: |
|
return string |
|
|
|
|
|
def _fix_sqrt(string): |
|
_string = re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string) |
|
_string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string) |
|
return _string |
|
|
|
|
|
def _fix_tan(string): |
|
_string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string) |
|
_string = re.sub(r"\\tan\s+(\w+)$", r"\\tan{\1}", _string) |
|
return _string |
|
|
|
|
|
def strip_string(string):
    r"""Normalize a math answer so that equivalent spellings compare equal.

    The input is converted with ``str()`` and then run through a fixed
    sequence of textual clean-ups. The order matters, because each step
    operates on the output of the previous one.

    Args:
        string: Answer to normalize; any object is accepted and stringified.

    Returns:
        The normalized string. It may be empty.
    """
    string = str(string).strip()

    # Drop line breaks, trailing periods and the LaTeX "\!" spacing command.
    string = string.replace("\n", "")
    string = string.rstrip(".")
    string = string.replace("\\!", "")

    # Unwrap an answer that starts with "\text{" and ends with "}".
    if string.startswith("\\text{") and string.endswith("}"):
        string = string.split("{", 1)[1][:-1]

    # Collapse the \tfrac / \dfrac / \cfrac variants to \frac.
    for variant in ("tfrac", "dfrac", "cfrac"):
        string = string.replace(variant, "frac")

    # Remove \left and \right.
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Drop a trailing \text{...} (presumably a unit), but only when that
    # leaves something behind.
    without_trailing_text = re.sub(r"\\text{.*?}$", "", string).strip()
    if without_trailing_text != "" and without_trailing_text != string:
        string = without_trailing_text

    # Remove degree markers.
    string = string.replace("^{\\circ}", "").strip()
    string = string.replace("^\\circ", "").strip()

    # Remove "{m}", "{cm}", "{mm}" (optionally followed by ^2 or ^3), a
    # trailing "p.m." and a trailing "t" that follows a digit.
    string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip()
    string = regex.sub(r"p\.m\.$", "", string).strip()
    string = regex.sub(r"(\d)\s*t$", r"\1", string).strip()

    # Remove dollar signs, escaped or not.
    string = string.replace("\\$", "")
    string = string.replace("$", "")

    string = string.replace("x\\in", "")

    # Unescape percent signs. The second literal used to be spelled "\%",
    # an invalid escape sequence that denotes the same two characters. The
    # second pass is kept because a doubled backslash before "%" leaves one
    # "\%" behind after the first pass.
    string = string.replace("\\%", "%")
    string = string.replace("\\%", "%")

    # Give bare decimals a leading zero: " .5" -> " 0.5", "{.5" -> "{0.5".
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    string = string.replace("\\cdot", "")

    # Canonicalize infinity to "\infty".
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    # Drop an explicit plus sign on infinity. The search text was previously
    # misspelled "+\inity", so this step never fired.
    string = string.replace("+\\infty", "\\infty")

    string = string.replace("\\mathbf", "")
    string = string.replace("\\mathrm", "")

    # Remove \mbox{...} groups.
    string = re.sub(r"\\mbox{.*?}", "", string)

    # Remove quote characters. str.replace returns a new string, so the
    # result has to be assigned; it was previously discarded.
    string = string.replace("'", "")
    string = string.replace('"', "")

    # Use "i" for "j" when the text has no "i" of its own.
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # Drop all-zero fractional parts: "3.00x" -> "3x", "3.0" -> "3".
    string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0+$", r"\1", string)

    if not string:
        return string
    if string[0] == ".":
        string = "0" + string

    # Brace \sqrt / \tan arguments before spaces are removed, since the
    # helpers look for whitespace-separated arguments.
    string = _fix_sqrt(string)
    string = _fix_tan(string)
    string = string.replace(" ", "")

    # Normalize fraction spellings.
    string = _fix_fracs(string)
    string = _fix_a_slash_b(string)

    # Strip trailing backslashes, commas and periods.
    string = regex.sub(r"(\\|,|\.)+$", "", string)

    return string
|
|
|
|
|
def extract_boxed_answers(text):
    """Collect the contents of every ``boxed{...}`` group found in ``text``.

    Braces are counted so that nested ``{...}`` groups stay inside the
    answer. A ``boxed{`` whose closing brace is never found contributes
    nothing. When the closing brace is immediately followed by ``%``, the
    closing brace itself is kept at the end of the extracted answer.

    Args:
        text: Text to scan.

    Returns:
        A list of extracted answers in order of appearance.
    """
    found = []
    for segment in text.split("boxed{")[1:]:
        depth = 0
        for pos, char in enumerate(segment):
            if char == "{":
                depth += 1
                continue
            if char != "}":
                continue
            depth -= 1
            if depth >= 0:
                continue
            # This "}" closes the boxed group itself.
            followed_by_percent = segment[pos + 1 : pos + 2] == "%"
            end = pos + 1 if followed_by_percent else pos
            found.append(segment[:end])
            break
    return found
|
|
|
|
|
def extract_program_output(pred_str):
    """Return the text of the last ```output block in ``pred_str``.

    The result is everything after the last "```output" marker up to the
    next "```" (or to the end of the text when there is none), with
    surrounding whitespace stripped. An empty string is returned when the
    marker is absent.
    """
    marker = "```output"
    if marker not in pred_str:
        return ""
    after_marker = pred_str.rpartition(marker)[2]
    block_body = after_marker.partition("```")[0]
    return block_body.strip()
|
|
|
|
|
def extract_answer(pred_str, exhaust=False):
    r"""Pull the predicted answer(s) out of a model response.

    Strategies are tried in order and the first one that applies wins:

    1. the text between "final answer is $" and "$. I hope";
    2. every ``boxed{...}`` group, via ``extract_boxed_answers``;
    3. whatever follows the last "he answer is";
    4. the last "```output" block, via ``extract_program_output``;
    5. the last number in the text, with commas removed first.

    Each candidate is then cut down to its first line, stripped of leading
    ":" and trailing "." and "/", and normalized with ``strip_string``.

    Args:
        pred_str: The model response.
        exhaust: When true, return every candidate; otherwise only the last.

    Returns:
        A list of normalized answers when ``exhaust`` is true, else the last
        normalized answer, or "" when none was found.
    """
    candidates = []
    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        tail = pred_str.split("final answer is $", 1)[1]
        candidates = [tail.split("$. I hope", 1)[0].strip()]
    elif "boxed" in pred_str:
        candidates = extract_boxed_answers(pred_str)
    elif "he answer is" in pred_str:
        candidates = [pred_str.split("he answer is")[-1].strip()]
    else:
        program_output = extract_program_output(pred_str)
        if program_output != "":
            candidates.append(program_output)
        else:
            # Raw string: the previous non-raw literal contained the invalid
            # escape sequences "\d" and "\.", which newer Python versions
            # warn about. The pattern itself is unchanged.
            numbers = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
            if numbers:
                candidates.append(numbers[-1])

    cleaned = []
    for candidate in candidates:
        # Only the first line of a multi-line candidate is kept.
        candidate = candidate.strip().split("\n")[0]
        candidate = candidate.lstrip(":").rstrip(".").rstrip("/")
        cleaned.append(strip_string(candidate))

    if exhaust:
        return cleaned
    return cleaned[-1] if cleaned else ""
|
|
|
|
|
def extract_math_answer(question, reasoning, task):
    r"""Extract every answer from ``reasoning`` and split compound answers.

    Each answer returned by ``extract_answer(reasoning, exhaust=True)`` is
    expanded as follows:

    * if ``question`` contains "separated by commas" and the answer has none
      of the characters ``()[]``, it is split on commas;
    * otherwise, if it contains ``\text{and}`` (whitespace allowed inside
      the braces), it is split on that;
    * otherwise it is kept whole.

    Every resulting piece is whitespace-stripped. ``task`` is not used.

    Returns:
        A flat list of answer strings.
    """
    and_marker = r"\\text\{\s*and\s*\}"
    collected = []
    for ans in extract_answer(reasoning, exhaust=True):
        if "separated by commas" in question and not any(
            ch in ans for ch in "()[]"
        ):
            parts = ans.split(",")
        elif regex.search(and_marker, ans):
            parts = regex.sub(and_marker, "[SEP]", ans).split("[SEP]")
        else:
            parts = [ans]
        collected.extend(part.strip() for part in parts)
    return collected
|
|
|
|
|
def extract_math_few_shot_cot_answer(question, reasoning, task):
    """Run ``extract_math_answer`` on the text before the first "Problem:".

    Anything from the first "Problem:" onward is discarded before
    extraction; when the marker is absent the whole text is used.
    """
    marker = "Problem:"
    own_solution = (
        reasoning.split(marker, 1)[0] if marker in reasoning else reasoning
    )
    return extract_math_answer(question, own_solution, task)
|
|
|
|
|
def extract_last_single_answer(question, reasoning, task):
    """Return only the last answer that ``extract_answer`` finds.

    ``question`` and ``task`` are accepted but not used.
    """
    last_answer = extract_answer(reasoning, exhaust=False)
    return last_answer
|
|
|
|
|
def extract_gsm_few_shot_cot_answer(question, reasoning, task):
    """Return the last number that appears before the first "Q: ".

    A number is an optional minus sign, digits, an optional dot and optional
    further digits, so a trailing dot directly after the digits is part of
    the match. Returns "[invalid]" when no number is present. ``question``
    and ``task`` are not used.
    """
    marker = "Q: "
    own_answer = (
        reasoning.split(marker, 1)[0] if marker in reasoning else reasoning
    )
    numbers = regex.findall(r"-?\d+\.?\d*", own_answer)
    return numbers[-1] if numbers else "[invalid]"
|
|
|
|
|
def extract_agieval_gaokao_mathcloze_few_shot_cot_test(question, reasoning, task):
    """Extract the answer that follows "答案是" as a one-element list.

    Text from the first "问题 " onward is dropped first. The answer is the
    first line after the first "答案是", stripped of surrounding whitespace
    and "$" characters. Returns ``["placeholder"]`` when "答案是" is absent.
    ``question`` and ``task`` are not used.
    """
    question_marker = "问题 "
    answer_marker = "答案是"
    if question_marker in reasoning:
        reasoning = reasoning.split(question_marker, 1)[0]
    if answer_marker not in reasoning:
        return ["placeholder"]
    after_marker = reasoning.split(answer_marker, 1)[1].strip()
    first_line = after_marker.split("\n")[0].strip()
    return [first_line.strip("$")]
|
|
|
|
|
def extract_agieval_gaokao_mathqa_few_shot_cot_test(question, reasoning, task):
    """Extract the answer that follows "答案是" as a string.

    Text from the first "问题 " onward is dropped first. The answer is the
    first line after the first "答案是", stripped of surrounding whitespace.
    Returns "placeholder" when "答案是" is absent. ``question`` and ``task``
    are not used.
    """
    question_marker = "问题 "
    answer_marker = "答案是"
    if question_marker in reasoning:
        reasoning = reasoning.split(question_marker, 1)[0]
    if answer_marker not in reasoning:
        return "placeholder"
    after_marker = reasoning.split(answer_marker, 1)[1].strip()
    return after_marker.split("\n")[0].strip()
|
|
|
|
|
def extract_sat_few_shot_answer(question, reasoning, task):
    """Extract a multiple-choice letter announced as "the final answer is".

    Only the text before the first "Problem:" is searched, and the search is
    case-insensitive because the text is lower-cased first. The letter may
    be wrapped in parentheses.

    Returns:
        The letter in upper case ("A" to "D"), or "placeholder" when the
        phrase is not found. ``question`` and ``task`` are not used.
    """
    own_solution = reasoning.partition("Problem:")[0]
    match = regex.search(
        r"the final answer is \(?(?P<ans>[abcd])\)?", own_solution.lower()
    )
    if match is None:
        return "placeholder"
    return match.group("ans").upper()
|
|
|
|
|
def extract_ocwcourses_few_shot_answer(question, reasoning, task):
    """Extract the text between "final answer is " and ". I hope it is correct".

    Only the text before the first "Problem:" is searched. When the phrase
    is not found, the searched text is printed to stdout as a debug aid and
    "[invalid]" is returned. ``question`` and ``task`` are not used.
    """
    marker = "Problem:"
    own_solution = (
        reasoning.split(marker, 1)[0] if marker in reasoning else reasoning
    )
    match = regex.search(
        r"final answer is (?P<ans>.*)\. I hope it is correct.", own_solution
    )
    if match is not None:
        return match.group("ans")
    print(f"DEBUG >>>\n{own_solution}", flush=True)
    return "[invalid]"
|
|
|
|
|
def extract_mmlu_stem(question, reasoning, task):
    """Extract a multiple-choice letter via ``extract_sat_few_shot_answer``.

    The text is cut at the first "Problem:" before it is handed on.
    """
    marker = "Problem:"
    own_solution = (
        reasoning.split(marker, 1)[0] if marker in reasoning else reasoning
    )
    return extract_sat_few_shot_answer(question, own_solution, task)
|
|
|
|
|
def extract_minif2f_isabelle(question, reasoning, task):
    """Return the text before the first "Informal:", whitespace-stripped.

    When the marker is absent the whole text is returned (stripped).
    ``question`` and ``task`` are not used.
    """
    before_informal = reasoning.partition("Informal:")[0]
    return before_informal.strip()
|
|
|
|
|
def extract_cmath_few_shot_test(question, reasoning, task):
    """Extract the numeric answer that follows "答案是".

    Text from the first "问题:" onward is dropped first. If "答案是" is
    present, the first line after it is trimmed of ":" and "。" and the last
    number on that line is returned. When that line holds no number, the
    text is printed to stdout as a debug aid and "[invalid]" is returned.
    If "答案是" is absent, the result of ``extract_last_single_answer`` is
    returned instead.

    Args:
        question: Passed through to ``extract_last_single_answer``.
        reasoning: The model response.
        task: Passed through to ``extract_last_single_answer``.

    Returns:
        The extracted answer string, or "[invalid]".
    """
    if "问题:" in reasoning:
        reasoning = reasoning.split("问题:", 1)[0]
    if "答案是" not in reasoning:
        return extract_last_single_answer(question, reasoning, task)

    answer_line = reasoning.split("答案是", 1)[1].strip().split("\n")[0]
    answer_line = answer_line.strip(":").strip("。")
    numbers = regex.findall(r"-?\d+\.?\d*", answer_line)
    # Explicit emptiness check instead of indexing inside a bare ``except``,
    # which would also have swallowed unrelated errors.
    if not numbers:
        print(f"DEBUG CMATH: {reasoning}", flush=True)
        return "[invalid]"
    return numbers[-1]
|
|