|
import multiprocessing |
|
from math import isclose |
|
import numpy as np |
|
from typing import Union, Any, Dict |
|
|
|
from sympy import simplify, N |
|
from sympy.parsing.sympy_parser import parse_expr |
|
from sympy.parsing.latex import parse_latex |
|
import re |
|
import regex |
|
|
|
from data_processing.answer_extraction import ( |
|
extract_answer, |
|
extract_program_output, |
|
strip_string, |
|
) |
|
|
|
|
|
def extract_program(result: str, last_only=True):
    """
    Collect the source found inside "```python" ... "```" fenced blocks.

    With ``last_only`` true, only the most recently opened python block is
    kept.  Otherwise every block is kept, each preceded by a
    "# ========" separator line.
    """
    chunks = []
    inside = False
    for row in result.split("\n"):
        if row.startswith("```python"):
            # A new python fence either resets the buffer or adds a divider.
            if last_only:
                chunks.clear()
            else:
                chunks.append("\n# ========\n")
            inside = True
        elif row.startswith("```"):
            # Any other fence line closes the current block.
            inside = False
        elif inside:
            chunks.append(row + "\n")
    return "".join(chunks)
|
|
|
|
|
def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return the ``(chain_of_thought, answer)`` ground truth for one example.

    Examples that already carry a ``gt_cot`` key are returned directly, with
    ``example["gt"]`` passed through ``strip_string``.  Otherwise the fields
    that are read depend on ``data_name``; an unrecognised name raises
    ``NotImplementedError``.  The chain of thought is stringified and
    stripped, and the answer goes through ``strip_string``.
    """
    if "gt_cot" in example:
        return example["gt_cot"], strip_string(example["gt"])

    if data_name in ("math", "ocw"):
        cot = example["solution"]
        answer = extract_answer(cot)
    elif data_name == "gsm8k":
        cot, answer = example["answer"].split("####")
    elif data_name == "gsm-hard":
        cot = example["code"]
        answer = example["target"]
    elif data_name == "svamp":
        cot = example["Equation"]
        answer = example["Answer"]
    elif data_name == "asdiv":
        cot = example["formula"]
        # Remove parenthesised spans from the answer text.
        answer = re.sub(r"\(.*?\)", "", example["answer"])
    elif data_name == "tabmwp":
        cot = example["solution"]
        answer = example["answer"]
        if example["ans_type"] in ("integer_number", "decimal_number"):
            # Numeric answers may be written as fractions, with thousands
            # separators, or as percentages.
            if "/" in answer:
                numerator, denominator = answer.split("/")[:2]
                answer = int(numerator) / int(denominator)
            elif "," in answer:
                answer = float(answer.replace(",", ""))
            elif "%" in answer:
                answer = float(answer.split("%")[0]) / 100
            else:
                answer = float(answer)
    elif data_name in ("mawps", "bbh"):
        # These datasets provide only a target, no worked solution.
        cot = None
        answer = example["target"]
    else:
        raise NotImplementedError(data_name)

    return str(cot).strip(), strip_string(answer)
|
|
|
|
|
def parse_question(example, data_name):
    """Build the question text for ``example`` according to ``data_name``.

    ``asdiv`` and ``svamp`` join a body and a question, ``tabmwp`` renders a
    table prompt (with optional title and choices), and every other dataset
    uses the first of the keys ``question``, ``problem``, ``Question``,
    ``input`` that is present.  An empty result fails an assertion; the
    returned text is stripped.
    """
    if data_name == "asdiv":
        body = example["body"].strip()
        ask = example["question"].strip()
        question = f"{body} {ask}"
    elif data_name == "svamp":
        statement = example["Body"].strip()
        # Make sure the body ends a sentence before the question is appended.
        if not statement.endswith("."):
            statement += "."
        question = f'{statement} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title = example["table_title"]
        title_str = f'regarding "{title}" ' if title else ""
        pieces = [
            f"Read the following table {title_str}and answer a question:\n",
            f'{example["table"]}\n{example["question"]}',
        ]
        choices = example["choices"]
        if choices:
            pieces.append(f" Please select from the following options: {choices}")
        question = "".join(pieces)
    else:
        candidate_keys = ("question", "problem", "Question", "input")
        question = next(
            (example[key] for key in candidate_keys if key in example), ""
        )
    assert question != ""
    return question.strip()
|
|
|
|
|
def run_execute(executor, result, prompt_type, execute=False):
    """Turn a raw model ``result`` into ``(prediction, report)``.

    Empty results and the literal ``"error"`` yield ``(None, None)``.
    Otherwise the prediction comes from the program output (when
    ``prompt_type`` contains ``"program_only"``), from running the extracted
    program through ``executor.apply`` (``pot``/``pal`` prompts with
    ``execute`` set), or from ``extract_answer``.  The prediction is passed
    through ``strip_string``; ``report`` is only set on the executor path.
    """
    if not result or result == "error":
        return None, None

    if "program_only" in prompt_type:
        raw_prediction, report = extract_program_output(result), None
    elif prompt_type in ("pot", "pal") and execute:
        raw_prediction, report = executor.apply(extract_program(result))
    else:
        raw_prediction, report = extract_answer(result), None

    return strip_string(raw_prediction), report
|
|
|
|
|
def parse_digits(num):
    r"""Convert a number-like value to ``float``, or return ``None``.

    Commas (thousands separators) are removed first.  If the plain
    conversion fails and the text ends with ``%`` (optionally preceded by a
    backslash, as in ``\%``), the remaining text is parsed and divided by
    100, so ``"50%"`` gives ``0.5``.

    Args:
        num: Any value; it is converted with ``str()`` before parsing.

    Returns:
        The parsed ``float``, or ``None`` when the text is not numeric.
    """
    text = str(num).replace(",", "")
    try:
        return float(text)
    except ValueError:
        # Not a plain number; fall through to the percentage form.  Only
        # ValueError is caught so KeyboardInterrupt/SystemExit propagate.
        pass

    if not text.endswith("%"):
        return None

    text = text[:-1]
    # Drop the escaping backslash of a "\%" suffix.
    if text.endswith("\\"):
        text = text[:-1]
    try:
        return float(text) / 100
    except ValueError:
        return None
|
|
|
|
|
def is_digit(num):
    """Return ``True`` when ``parse_digits`` yields a value for ``num``."""
    parsed = parse_digits(num)
    return parsed is not None
|
|
|
|
|
def normalize_prediction(prediction):
    """Normalise a prediction into a canonical form for comparison.

    Numeric predictions (per ``is_digit``) are rounded to 6 decimals and
    returned as a string.  The remaining code strips enclosing ``[]``/``()``
    pairs, attempts a sympy parse, and removes brace/parenthesis characters.

    NOTE(review): the file's indentation was lost; the nesting below is a
    reconstruction and should be confirmed against the upstream source.
    """
    try:
        # 1. Numeric case: drop thousands separators and round.
        if is_digit(prediction):
            prediction = np.round(float(str(prediction).replace(",", "")), 6)
        # NOTE(review): assumed to sit at ``try`` level (not inside the
        # ``if``), so the code after this block only runs when the ``try``
        # body raises, e.g. ``float("50%")`` — confirm.
        return str(prediction)
    except:
        pass

    # 2. Symbolic case: work on the stripped string form.
    prediction = str(prediction).strip()

    # Peel off matching outer [] / () pairs.
    # NOTE(review): ``bracket`` is assigned but never appended to
    # ``brackets``, so ``brackets`` stays empty and both ``if brackets``
    # branches below never execute.
    brackets = []
    while (
        prediction.startswith("[")
        and prediction.endswith("]")
        or (prediction.startswith("(") and prediction.endswith(")"))
    ):
        bracket = prediction[0]
        prediction = prediction[1:-1]
    if brackets and "," in prediction:
        # Normalise each comma-separated component recursively.
        pred_parts = [normalize_prediction(part) for part in prediction.split(",")]
        prediction = ",".join(pred_parts)

    if brackets:
        # Re-wrap in the brackets that were removed, innermost first.
        for b in reversed(brackets):
            if b == "[":
                prediction = "[" + prediction + "]"
            else:
                assert b == "("
                prediction = "(" + prediction + ")"

    def _parse(s):
        # Try the LaTeX parser, then the plain expression parser; return the
        # input unchanged if both fail.
        for f in [parse_latex, parse_expr]:
            try:
                return f(s)
            except:
                pass
        return s

    prediction = _parse(prediction)

    # NOTE(review): ``_parse`` may return a sympy object rather than a str;
    # these ``replace`` calls presumably target the string fallback — confirm.
    for s in ["{", "}", "(", ")"]:
        prediction = prediction.replace(s, "")

    return prediction
|
|
|
|
|
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    The checks run in order: string equality, numeric comparison,
    element-wise comparison of bracketed tuples/lists, element-wise
    comparison of pmatrix/bmatrix bodies, equation handling, and finally
    ``symbolic_equal``.

    Args:
        prediction: The predicted answer.
        reference: The ground-truth answer.
        include_percentage: When both values are numeric, also accept the
            reference divided or multiplied by 100.
        is_close: Compare numbers with ``isclose(..., abs_tol=1e-3)`` instead
            of ``==``.
        timeout: Run the final symbolic comparison through
            ``call_with_timeout`` instead of calling it directly.

    Returns:
        ``True`` if any check accepts the pair, otherwise ``False``.
    """
    # 0. Identical string forms.
    if str(prediction) == str(reference):
        return True

    # 1. Numeric comparison; decisive when both sides parse as numbers.
    try:
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)

            # Candidate reference values (percentage vs. fraction forms).
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, abs_tol=1e-3):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            return False
    except:
        pass

    # A falsy prediction other than 0 / False (e.g. "" or None) cannot match.
    if not prediction and prediction not in [0, False]:
        return False

    # 2. From here on both sides are compared as stripped strings.
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # Bracketed sequences such as "(a,b)" or "[a,b]": compare element-wise.
    # ``regex.match`` only anchors at the start of the string.
    if (
        regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
        and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            # The recursive calls do not forward ``timeout``.
            if all(
                [
                    math_equal(
                        pred_parts[i], ref_parts[i], include_percentage, is_close
                    )
                    for i in range(len(pred_parts))
                ]
            ):
                return True

    # LaTeX matrices: both sides must be wrapped in pmatrix/bmatrix.
    if (
        (
            prediction.startswith("\\begin{pmatrix}")
            or prediction.startswith("\\begin{bmatrix}")
        )
        and (
            prediction.endswith("\\end{pmatrix}")
            or prediction.endswith("\\end{bmatrix}")
        )
        and (
            reference.startswith("\\begin{pmatrix}")
            or reference.startswith("\\begin{bmatrix}")
        )
        and (
            reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
        )
    ):
        # The pmatrix and bmatrix markers have the same length, so one slice
        # strips either wrapper.  Rows are separated by "\\\\" (two
        # backslashes in the text).
        pred_lines = [
            line.strip()
            for line in prediction[
                len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
            ].split("\\\\")
            if line.strip()
        ]
        ref_lines = [
            line.strip()
            for line in reference[
                len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
            ].split("\\\\")
            if line.strip()
        ]
        matched = True
        if len(pred_lines) == len(ref_lines):
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                # Cells within a row are separated by "&".
                pred_parts = pred_line.split("&")
                ref_parts = ref_line.split("&")
                if len(pred_parts) == len(ref_parts):
                    if not all(
                        [
                            math_equal(
                                pred_parts[i],
                                ref_parts[i],
                                include_percentage,
                                is_close,
                            )
                            for i in range(len(pred_parts))
                        ]
                    ):
                        matched = False
                        break
                else:
                    matched = False
                if not matched:
                    break
        else:
            matched = False
        if matched:
            return True

    # Equations.  When both sides contain exactly one "=", compare
    # "lhs - (rhs)" symbolically, also accepting the negated form.
    if prediction.count("=") == 1 and reference.count("=") == 1:
        pred = prediction.split("=")
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split("=")
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    elif (
        # Prediction looks like "x = value" (left side at most 2 characters):
        # compare only the right-hand side.
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        if math_equal(
            prediction.split("=")[1], reference, include_percentage, is_close
        ):
            return True
    elif (
        # Same as above with the roles of reference and prediction swapped.
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        if math_equal(
            prediction, reference.split("=")[1], include_percentage, is_close
        ):
            return True

    # 3. Symbolic comparison, optionally in a child process with a timeout.
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

    return False
|
|
|
|
|
def math_equal_process(param):
    """Call ``math_equal`` on the last two items of ``param``."""
    prediction, reference = param[-2], param[-1]
    return math_equal(prediction, reference)
|
|
|
|
|
def symbolic_equal(a, b):
    """Return ``True`` when ``a`` and ``b`` are symbolically or numerically equal.

    Each argument is parsed with ``parse_latex`` and, if that fails, with
    ``parse_expr``; if both fail, the argument is used unchanged.  The pair
    is accepted if ``simplify(a - b) == 0`` or if the numeric values
    ``N(a)`` and ``N(b)`` are within ``abs_tol=1e-3``.  Errors in either
    check count as "not equal".

    Only ``Exception`` subclasses are suppressed, so ``KeyboardInterrupt``
    and ``SystemExit`` still propagate to the caller.
    """

    def _parse(s):
        # Try each parser in turn; return the raw input if none accepts it.
        for f in [parse_latex, parse_expr]:
            try:
                return f(s)
            except Exception:
                pass
        return s

    a = _parse(a)
    b = _parse(b)

    # Exact symbolic equality.
    try:
        if simplify(a - b) == 0:
            return True
    except Exception:
        pass

    # Numeric closeness of the evaluated expressions.
    try:
        if isclose(N(a), N(b), abs_tol=1e-3):
            return True
    except Exception:
        pass
    return False
|
|
|
|
|
def symbolic_equal_process(a, b, output_queue):
    """Compute ``symbolic_equal(a, b)`` and put the result on ``output_queue``."""
    output_queue.put(symbolic_equal(a, b))
|
|
|
|
|
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run ``func`` in a child process and return the value it reports.

    ``func`` is invoked as ``func(*args, output_queue, **kwargs)`` and is
    expected to put its result on ``output_queue``.

    Args:
        func: Callable executed in the child process.
        *args: Positional arguments placed before the queue argument.
        timeout: Seconds to wait for the child before terminating it.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value read from the queue, or ``False`` when the child is still
        running after ``timeout`` (it is terminated) or exits without
        putting a result (for example because it raised).
    """
    from queue import Empty  # raised by multiprocessing.Queue.get on timeout

    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=func, args=args + (output_queue,), kwargs=kwargs
    )
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    # The child has exited.  Use a bounded wait so that a child that died
    # without reporting a result cannot block this call forever.
    try:
        return output_queue.get(timeout=0.1)
    except Empty:
        return False
|
|