|
from typing import Dict, List |
|
|
|
import datasets |
|
|
|
|
|
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    r"""Attach the gold final answer to every example of *dataset*.

    The mapping function returns, for each row, a dict with the row's
    ``problem`` and ``solution`` plus an ``answer`` field holding the inner
    text of the last boxed expression (``\boxed{...}``) found in the
    solution.
    """

    def _with_answer(doc: dict) -> dict:
        problem = doc["problem"]
        solution = doc["solution"]
        gold_box = last_boxed_only_string(solution)
        return {
            "problem": problem,
            "solution": solution,
            "answer": remove_boxed(gold_box),
        }

    return dataset.map(_with_answer)
|
|
|
|
|
def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
    """Score a single model generation against the gold answer of *doc*.

    Args:
        doc: dataset row; only ``doc["solution"]`` is read, and the gold
            answer is the inner text of its last boxed expression.
        results: model generations; only ``results[0]`` is scored.

    Returns:
        ``{"exact_match": 1}`` when the extracted answer is equivalent to the
        gold answer according to ``is_equiv``, otherwise
        ``{"exact_match": 0}``.
    """
    generation = results[0]

    # With two or more "$" signs the candidate answer is the text between the
    # first and the last one; with zero or one "$" (find == rfind) the whole
    # generation is used as the candidate.
    first = generation.find("$")
    last = generation.rfind("$")
    answer = generation if first == last else generation[first + 1 : last]

    gold = remove_boxed(last_boxed_only_string(doc["solution"]))

    # Return a fresh dict rather than rebinding the ``results`` parameter
    # (a list of generations) to a value of a different type.
    return {"exact_match": 1 if is_equiv(answer, gold) else 0}
|
|
|
|
|
|
|
def is_equiv(str1, str2, verbose=False):
    """Report whether two answer strings are equivalent after normalization.

    Two ``None`` values count as equivalent (with a printed warning); a single
    ``None`` never matches. Otherwise both strings are normalized with
    ``strip_string`` and compared. If normalization raises any exception, the
    raw strings are compared instead. With ``verbose`` set, the normalized
    forms are printed before comparing.
    """
    if str1 is None or str2 is None:
        if str1 is None and str2 is None:
            print("WARNING: Both None")
            return True
        return False

    try:
        normalized1 = strip_string(str1)
        normalized2 = strip_string(str2)
        if verbose:
            print(normalized1, normalized2)
    except Exception:
        # Best-effort fallback: normalization failed, so compare verbatim.
        return str1 == str2
    return normalized1 == normalized2
|
|
|
|
|
def remove_boxed(s):
    r"""Strip the box command from a boxed expression and return its content.

    Accepts every shape ``last_boxed_only_string`` can produce:

    * ``\boxed{...}`` and ``\fbox{...}`` -> the text between the braces;
    * ``\boxed ...`` (brace-less form)   -> the text after ``\boxed ``;
    * ``None`` (no box found)            -> ``None``.

    Raises:
        AssertionError: if *s* is a string in none of the forms above.
    """
    if s is None:
        return None

    space_form = "\\boxed "
    if s.startswith(space_form):
        return s[len(space_form) :]

    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left) and s.endswith("}"):
            return s[len(left) : -1]

    # Raised explicitly rather than via ``assert`` so the check survives
    # ``python -O``; AssertionError is the exception type existing callers
    # of this helper may already expect.
    raise AssertionError(f"not a boxed expression: {s!r}")
|
|
|
|
|
def last_boxed_only_string(string):
    r"""Return the last boxed expression in *string*, including the command.

    * If the last ``\boxed`` occurrence is the brace-less form ``\boxed x``,
      returns ``"\boxed "`` followed by the text after it up to the next
      ``$`` (or the end of the string).
    * Otherwise, starting at the last ``\boxed`` (or, if there is none, the
      last ``\fbox``), returns the span up to and including the first ``}``
      that brings the running brace depth back to zero.
    * Returns ``None`` if neither command occurs or no such ``}`` exists.
    """
    idx = string.rfind("\\boxed")

    # Use the brace-less form only when it really is the LAST \boxed, so an
    # earlier "\boxed 4" cannot win over a later "\boxed{6}".
    if idx >= 0 and string.startswith("\\boxed ", idx):
        tail = string[idx + len("\\boxed ") :]
        return "\\boxed " + tail.split("$")[0]

    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    # Walk forward from the command while tracking brace depth; the first "}"
    # that returns the depth to zero closes the box.
    depth = 0
    for pos in range(idx, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[idx : pos + 1]
    return None
|
|
|
|
|
def fix_fracs(string):
    r"""Add braces to ``\frac`` arguments written without them.

    ``\frac12`` -> ``\frac{1}{2}`` and ``\frac1{2}`` -> ``\frac{1}{2}``; a
    ``\frac`` whose first argument is already braced is copied unchanged.
    Only one character is taken per unbraced argument. If any unbraced
    ``\frac`` has fewer than two characters after it (including a ``\frac``
    at the very end of the string), the input is returned unchanged.
    """
    head, *tails = string.split("\\frac")
    pieces = [head]
    for tail in tails:
        pieces.append("\\frac")
        if tail.startswith("{"):
            pieces.append(tail)
            continue
        if len(tail) < 2:
            # Too short to hold two arguments: leave the whole input as-is
            # (an explicit check, so it also covers an empty tail and is not
            # stripped by ``python -O`` like an assert would be).
            return string
        first, second, rest = tail[0], tail[1], tail[2:]
        if second == "{":
            # Second argument is already braced; only wrap the first.
            pieces.append("{" + first + "}" + second + rest)
        else:
            pieces.append("{" + first + "}{" + second + "}" + rest)
    return "".join(pieces)
|
|
|
|
|
def fix_a_slash_b(string):
    r"""Rewrite an integer fraction ``a/b`` as ``\frac{a}{b}``.

    Only strings consisting of exactly one ``/`` between two integers in
    canonical spelling (so ``"1/2"`` and ``"-3/4"``, but not ``"01/2"`` or
    ``" 1/2"``) are rewritten. Anything else, including non-integer parts
    such as ``\pi/2``, is returned unchanged.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    numerator, denominator = parts
    try:
        a = int(numerator)
        b = int(denominator)
    except ValueError:
        # Not an integer fraction: leave it alone instead of letting the
        # ValueError escape to the caller.
        return string
    # Explicit check rather than an assert so it also holds under python -O.
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"
|
|
|
|
|
def remove_right_units(string):
    r"""Drop a unit annotation written on the right as ``\text{ ...}``.

    Everything from the ``\text{ `` marker (with a space after the brace)
    onward is discarded; strings without that marker are returned as-is.
    Exactly one marker is expected, and more than one trips the assertion.
    """
    marker = "\\text{ "
    if marker not in string:
        return string
    before, *after = string.split(marker)
    assert len(after) == 1
    return before
|
|
|
|
|
def fix_sqrt(string):
    r"""Add braces to single-character square-root arguments.

    ``\sqrt2`` -> ``\sqrt{2}``; an already-braced ``\sqrt{...}`` is kept.
    Only the single character right after ``\sqrt`` is wrapped, so
    ``\sqrt25`` becomes ``\sqrt{2}5``. A ``\sqrt`` with nothing after it
    (end of string, or directly followed by another ``\sqrt``) is kept
    unchanged instead of raising IndexError.
    """
    if "\\sqrt" not in string:
        return string
    head, *tails = string.split("\\sqrt")
    pieces = [head]
    for tail in tails:
        if tail and tail[0] != "{":
            pieces.append("\\sqrt{" + tail[0] + "}" + tail[1:])
        else:
            # Already braced, or nothing to wrap.
            pieces.append("\\sqrt" + tail)
    return "".join(pieces)
|
|
|
|
|
def strip_string(string):
    r"""Normalize a LaTeX answer so that equivalent spellings compare equal.

    Applies a fixed sequence of textual rewrites: remove linebreaks, spacing
    and sizing commands; canonicalize fractions and square roots; strip
    degrees, dollar signs, percent signs and right-hand units; drop a short
    ``x=`` prefix. Returns the normalized string. Errors raised by the helper
    functions propagate to the caller.
    """
    # Remove linebreaks.
    string = string.replace("\n", "")

    # Remove inverse thin spaces ("\!").
    string = string.replace("\\!", "")

    # Replace an escaped double backslash with a single backslash.
    string = string.replace("\\\\", "\\")

    # Treat \tfrac and \dfrac as \frac.
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # Remove \left / \right size modifiers. NOTE: this is a plain substring
    # replacement, so e.g. "\rightarrow" also loses its "\right" prefix.
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove degree markers.
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # Remove escaped dollar signs.
    string = string.replace("\\$", "")

    # Remove units written on the right as "\text{ ...}".
    string = remove_right_units(string)

    # Remove escaped percent signs. A single replace suffices: the unescaped
    # literal "\%" denotes the same two characters as "\\%" (and is an invalid
    # escape sequence that newer Pythons warn about).
    string = string.replace("\\%", "")

    # Add a leading zero to bare decimals: " .5" -> " 0.5", "{.5" -> "{0.5".
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string

    # Drop a short left-hand side such as "x=" or "k=" in "x=5".
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # Brace single-character \sqrt arguments: "\sqrt3" -> "\sqrt{3}".
    string = fix_sqrt(string)

    # Remove all spaces.
    string = string.replace(" ", "")

    # Brace \frac arguments: "\frac12" -> "\frac{1}{2}".
    string = fix_fracs(string)

    # Special case: canonicalize exactly "0.5" as \frac{1}{2}; other decimals
    # are not converted.
    if string == "0.5":
        string = "\\frac{1}{2}"

    # Rewrite an integer fraction "a/b" as \frac{a}{b}.
    string = fix_a_slash_b(string)

    return string
|
|