# RPC / eval / eval_utils.py
# Uploaded by WNJXYK ("Upload 16 files", commit 22c93a7, verified).
import multiprocessing
from math import isclose
import numpy as np
from typing import Union, Any, Dict
from sympy import simplify, N
from sympy.parsing.sympy_parser import parse_expr
from sympy.parsing.latex import parse_latex
import re
import regex
from data_processing.answer_extraction import (
extract_answer,
extract_program_output,
strip_string,
)
def extract_program(result: str, last_only=True):
    """Return the code found inside ```python fenced blocks of *result*.

    A line starting with "```python" opens a block; any other line starting
    with "```" closes it. Lines in between are collected, each terminated by
    a newline.

    Args:
        result: text (e.g. a model completion) to scan line by line.
        last_only: if True, only the last python block is returned; otherwise
            every block is returned, each preceded by a "# ========" separator.
    """
    chunks = []
    inside = False
    for row in result.split("\n"):
        if row.startswith("```python"):
            # A new block either replaces what was collected or is appended
            # after a separator line.
            if last_only:
                chunks = []
            else:
                chunks.append("\n# ========\n")
            inside = True
        elif row.startswith("```"):
            inside = False
        elif inside:
            chunks.append(row + "\n")
    return "".join(chunks)
def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return ``(gt_cot, gt_ans)``, the reference reasoning and answer.

    If *example* already has a ``"gt_cot"`` key, that value is returned as-is
    together with ``strip_string(example["gt"])``. Otherwise the fields are
    read according to *data_name*:

    * ``math`` / ``ocw``: ``solution``; answer via ``extract_answer``.
    * ``gsm8k``: ``answer`` split on ``"####"``.
    * ``gsm-hard``: ``code`` / ``target``.
    * ``svamp``: ``Equation`` / ``Answer``.
    * ``asdiv``: ``formula``; ``answer`` with parenthesised parts removed.
    * ``mawps`` / ``bbh``: no reasoning; ``target``.
    * ``tabmwp``: ``solution`` / ``answer``; for numeric ``ans_type`` the
      answer is converted to a number (fraction, thousands separators,
      percentage, or plain).

    The reasoning is returned as ``str(gt_cot).strip()``, so a missing one
    becomes ``"None"``. The answer is passed through ``strip_string``.

    Raises:
        NotImplementedError: for an unsupported *data_name*.
    """
    if "gt_cot" in example:
        return example["gt_cot"], strip_string(example["gt"])

    def _tabmwp_number(raw):
        # Formats are checked in this order: "a/b", "1,234", "12%", plain.
        if "/" in raw:
            pieces = raw.split("/")
            return int(pieces[0]) / int(pieces[1])
        if "," in raw:
            return float(raw.replace(",", ""))
        if "%" in raw:
            return float(raw.split("%")[0]) / 100
        return float(raw)

    if data_name in ("math", "ocw"):
        gt_cot = example["solution"]
        gt_ans = extract_answer(gt_cot)
    elif data_name == "gsm8k":
        gt_cot, gt_ans = example["answer"].split("####")
    elif data_name == "gsm-hard":
        gt_cot, gt_ans = example["code"], example["target"]
    elif data_name == "svamp":
        gt_cot, gt_ans = example["Equation"], example["Answer"]
    elif data_name == "asdiv":
        gt_cot = example["formula"]
        gt_ans = re.sub(r"\(.*?\)", "", example["answer"])
    elif data_name in ("mawps", "bbh"):
        gt_cot, gt_ans = None, example["target"]
    elif data_name == "tabmwp":
        gt_cot = example["solution"]
        gt_ans = example["answer"]
        if example["ans_type"] in ("integer_number", "decimal_number"):
            gt_ans = _tabmwp_number(gt_ans)
    else:
        raise NotImplementedError(data_name)

    return str(gt_cot).strip(), strip_string(gt_ans)
def parse_question(example, data_name):
    """Build the question text for *example* of dataset *data_name*.

    * ``asdiv``: ``body`` and ``question`` joined by a space.
    * ``svamp``: ``Body`` (with a trailing "." added if missing), a space,
      then ``Question``.
    * ``tabmwp``: an instruction line mentioning the optional
      ``table_title``, then ``table`` and ``question`` on separate lines,
      plus the ``choices`` when present.
    * anything else: the first of ``question``, ``problem``, ``Question``,
      ``input`` present in *example*.

    The result is stripped of surrounding whitespace. An ``AssertionError``
    is raised when no question text was found.
    """
    if data_name == "asdiv":
        text = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        body = body if body.endswith(".") else body + "."
        text = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title = example["table_title"]
        title_str = f'regarding "{title}" ' if title else ""
        text = (
            f"Read the following table {title_str}and answer a question:\n"
            f'{example["table"]}\n{example["question"]}'
        )
        if example["choices"]:
            text += f' Please select from the following options: {example["choices"]}'
    else:
        candidates = ("question", "problem", "Question", "input")
        text = next((example[key] for key in candidates if key in example), "")
    assert text != ""
    return text.strip()
def run_execute(executor, result, prompt_type, execute=False):
    """Turn a model completion *result* into ``(prediction, report)``.

    Empty results and the literal ``"error"`` give ``(None, None)``.

    * Prompt types containing ``"program_only"``: the prediction comes from
      ``extract_program_output``.
    * ``"pot"`` / ``"pal"`` with *execute* set: the last python block is run
      through ``executor.apply``, which yields ``(prediction, report)``.
    * Everything else: ``extract_answer`` is used.

    The prediction is always passed through ``strip_string``. *report* is
    ``None`` except on the executor path.
    """
    if not result or result == "error":
        return None, None

    if "program_only" in prompt_type:
        return strip_string(extract_program_output(result)), None

    if prompt_type in ["pot", "pal"] and execute:
        prediction, report = executor.apply(extract_program(result))
        return strip_string(prediction), report

    return strip_string(extract_answer(result)), None
def parse_digits(num):
    """Parse *num* as a float, accepting thousands separators and percentages.

    *num* is converted with ``str()`` and every comma is removed. A plain
    number such as ``"1,234.5"`` is returned as a float. A trailing ``%``
    (optionally LaTeX-escaped as a backslash before it) divides the value
    by 100, e.g. ``"50%"`` -> ``0.5``.

    Returns:
        The parsed float, or ``None`` when the text is not numeric.
    """
    text = str(num).replace(",", "")
    try:
        return float(text)
    except ValueError:
        # float() on a str only raises ValueError; anything broader would
        # also swallow KeyboardInterrupt/SystemExit.
        pass
    if text.endswith("%"):
        text = text[:-1]
        if text.endswith("\\"):  # LaTeX "\%"
            text = text[:-1]
        try:
            return float(text) / 100
        except ValueError:
            pass
    return None
def is_digit(num):
    """Return True when ``parse_digits`` can read *num* as a number.

    Defined in terms of ``parse_digits`` so that both functions agree on
    what counts as numeric (thousands separators, trailing percent sign).
    """
    parsed = parse_digits(num)
    return parsed is not None
def normalize_prediction(prediction):
    """Return a normalised string form of *prediction*.

    1. If ``is_digit`` accepts it and ``float`` parses it after removing
       commas, return ``str`` of the value rounded to 6 decimals
       (``"1,000"`` -> ``"1000.0"``).
    2. Otherwise, work on ``str(prediction).strip()``:
       - peel off outer ``[...]`` / ``(...)`` pairs;
       - if any pairs were removed and the inside contains commas, normalise
         each comma-separated element recursively;
       - put the brackets back;
       - run the text through ``parse_latex`` then ``parse_expr`` (the first
         that succeeds wins, otherwise the text is kept);
       - convert the result to ``str`` and delete every ``{``, ``}``, ``(``
         and ``)``.

    NOTE(review): ``parse_expr`` evaluates its input with ``eval``; do not
    feed it untrusted text outside a sandbox.
    """
    # 1. numerical equal
    try:
        if is_digit(prediction):
            prediction = np.round(float(str(prediction).replace(",", "")), 6)
            return str(prediction)
    except ValueError:
        # e.g. "50%" is accepted by is_digit but rejected by float(); fall
        # through to the symbolic normalisation.
        pass

    # 2. symbolic equal
    prediction = str(prediction).strip()

    # Peel off outer [] / () pairs, recording them outermost-first.
    brackets = []
    while (prediction.startswith("[") and prediction.endswith("]")) or (
        prediction.startswith("(") and prediction.endswith(")")
    ):
        brackets.append(prediction[0])
        prediction = prediction[1:-1]

    if brackets and "," in prediction:
        prediction = ",".join(
            normalize_prediction(part) for part in prediction.split(",")
        )

    # Re-wrap innermost-first so nested shapes such as "[(a,b)]" survive.
    for bracket in reversed(brackets):
        opener, closer = ("[", "]") if bracket == "[" else ("(", ")")
        prediction = opener + prediction + closer

    def _parse(text):
        # Best effort: parser failures (syntax errors, missing LaTeX backend,
        # ...) mean "keep the raw text".
        for parser in (parse_latex, parse_expr):
            try:
                return parser(text)
            except Exception:
                continue
        return text

    # The parsers return sympy objects or plain containers (parse_expr("[1,2]")
    # is a list); str() them first, since sympy's Basic.replace is not
    # str.replace and lists have no .replace at all.
    prediction = str(_parse(prediction))
    for ch in ["{", "}", "(", ")"]:
        prediction = prediction.replace(ch, "")
    return prediction
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """Decide whether *prediction* and *reference* denote the same answer.

    Checks are tried in order, returning True at the first match:

    0. identical ``str`` forms;
    1. numerical equality: when both sides pass ``is_digit`` the outcome of
       this step is final. With *include_percentage* the reference is also
       tried scaled by 1/100 and by 100;
    2. element-wise comparison of bracketed tuples/intervals, and of
       ``pmatrix`` / ``bmatrix`` LaTeX matrices;
    3. equations: ``lhs = rhs`` on both sides compared as ``lhs - (rhs)``
       up to sign; or a short (<= 2 chars) ``x = value`` on one side compared
       with the bare value on the other;
    4. symbolic equality via ``symbolic_equal``.

    Args:
        prediction: model answer.
        reference: ground-truth answer.
        include_percentage: also accept ``reference / 100`` and
            ``reference * 100`` in the numeric check.
        is_close: compare numbers with ``isclose(abs_tol=1e-3)`` rather than
            ``==``.
        timeout: run every symbolic comparison, including those made by the
            recursive element and equation checks, in a child process via
            ``call_with_timeout``, so one slow ``simplify`` cannot hang the
            caller.

    Returns:
        True if any check succeeds, else False.
    """
    if str(prediction) == str(reference):
        return True

    def _symbolic(a, b):
        # Every sympy comparison honours `timeout`, not just the final one.
        if timeout:
            return call_with_timeout(symbolic_equal_process, a, b)
        return symbolic_equal(a, b)

    def _sub_equal(a, b):
        # Recursive comparison that keeps all caller options, including timeout.
        return math_equal(a, b, include_percentage, is_close, timeout=timeout)

    # 1. numerical equal
    try:
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, abs_tol=1e-3):
                            return True
                    elif item == prediction:
                        return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # 2a. tuples / intervals such as "(1, 2)" or "[0, 3)"
    bracketed = r"(\(|\[).+(\)|\])"
    if (
        regex.match(bracketed, prediction) is not None
        and regex.match(bracketed, reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and all(
            _sub_equal(p, r) for p, r in zip(pred_parts, ref_parts)
        ):
            return True

    # 2b. LaTeX matrices, compared row by row and cell by cell
    matrix_begins = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_ends = ("\\end{pmatrix}", "\\end{bmatrix}")
    if (
        prediction.startswith(matrix_begins)
        and prediction.endswith(matrix_ends)
        and reference.startswith(matrix_begins)
        and reference.endswith(matrix_ends)
    ):
        # pmatrix and bmatrix delimiters have the same length, so one slice
        # extracts the body of either kind.
        body = slice(len("\\begin{pmatrix}"), -len("\\end{pmatrix}"))

        def _rows(matrix):
            return [row.strip() for row in matrix[body].split("\\\\") if row.strip()]

        def _row_equal(pred_row, ref_row):
            pred_cells = pred_row.split("&")
            ref_cells = ref_row.split("&")
            return len(pred_cells) == len(ref_cells) and all(
                _sub_equal(p, r) for p, r in zip(pred_cells, ref_cells)
            )

        pred_lines = _rows(prediction)
        ref_lines = _rows(reference)
        if len(pred_lines) == len(ref_lines) and all(
            _row_equal(p, r) for p, r in zip(pred_lines, ref_lines)
        ):
            return True

    # 2c. equations
    if prediction.count("=") == 1 and reference.count("=") == 1:
        lhs, rhs = prediction.split("=")
        pred = f"{lhs.strip()} - ({rhs.strip()})"
        lhs, rhs = reference.split("=")
        ref = f"{lhs.strip()} - ({rhs.strip()})"
        # An equation may be written with its sides swapped, hence the sign flip.
        if _symbolic(pred, ref) or _symbolic(f"-({pred})", ref):
            return True
    elif (
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        # e.g. prediction "x = 5" vs reference "5"
        if _sub_equal(prediction.split("=")[1], reference):
            return True
    elif (
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        if _sub_equal(prediction, reference.split("=")[1]):
            return True

    # 3. symbolic equal with sympy
    return bool(_symbolic(prediction, reference))
def math_equal_process(param):
    """Single-argument wrapper around ``math_equal``.

    *param* is an indexable sequence whose last two entries are the
    prediction and the reference. ``math_equal`` runs with its default
    options.
    """
    prediction, reference = param[-2], param[-1]
    return math_equal(prediction, reference)
def symbolic_equal(a, b):
    """Return True if *a* and *b* are symbolically or numerically equal.

    Each side is parsed with ``parse_latex`` first, then ``parse_expr``. A
    side that neither parser accepts is kept as the raw value. The pair is
    equal when ``simplify(a - b) == 0``, or failing that when ``N(a)`` and
    ``N(b)`` agree within ``abs_tol=1e-3``. An error inside either check
    counts as "not equal" for that check.

    NOTE(review): ``parse_expr`` (and ``N`` on a raw string, which sympifies
    it) evaluates its input with ``eval``; only pass trusted text, or run
    this in a sandboxed process.
    """

    def _parse(s):
        for parser in (parse_latex, parse_expr):
            try:
                return parser(s)
            except Exception:
                # Best effort: syntax errors, missing LaTeX backend, etc.
                continue
        return s

    a = _parse(a)
    b = _parse(b)
    # `except Exception` (not a bare except) so Ctrl-C during a long
    # simplify() still interrupts instead of being swallowed.
    try:
        if simplify(a - b) == 0:
            return True
    except Exception:
        pass
    try:
        if isclose(N(a), N(b), abs_tol=1e-3):
            return True
    except Exception:
        pass
    return False
def symbolic_equal_process(a, b, output_queue):
    """Child-process entry point: put ``symbolic_equal(a, b)`` on *output_queue*.

    The signature matches ``call_with_timeout``, which appends the result
    queue as the last positional argument.
    """
    output_queue.put(symbolic_equal(a, b))
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run ``func(*args, output_queue, **kwargs)`` in a child process.

    *func* must report its result by putting it on the queue it receives as
    its last positional argument. The parent waits up to *timeout* seconds
    for the child to finish.

    Returns:
        The first item *func* put on the queue, or False if the child was
        still running after *timeout* (it is then terminated), exited with a
        non-zero exit code (e.g. *func* raised), or exited without putting
        anything.
    """
    from queue import Empty

    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=func, args=args + (output_queue,), kwargs=kwargs
    )
    process.start()
    process.join(timeout)
    if process.is_alive():
        process.terminate()
        process.join()
        return False
    # A crashed child never put a result; a blocking get() would wait forever.
    if process.exitcode != 0:
        return False
    try:
        # A child that exits normally has flushed its queued items to the
        # pipe first, so this returns promptly. The bound only guards
        # against a func that put nothing.
        return output_queue.get(timeout=1)
    except Empty:
        return False