|
import regex |
|
from copy import deepcopy |
|
from eval.eval_utils import math_equal |
|
from eval.ocwcourses_eval_utils import ( |
|
normalize_numeric, |
|
numeric_equality, |
|
normalize_symbolic_equation, |
|
SymbolicMathMixin, |
|
) |
|
|
|
|
|
def is_correct(item, pred_key="prediction", prec=1e-3):
    """Decide whether ``item[pred_key]`` matches ``item["answer"]``.

    Three input shapes are handled:

    * Both values are lists: every prediction must match at least one
      answer and every answer must be matched by at least one prediction.
    * Both values are strings containing ``\\cup``: each side is split on
      ``\\cup`` and the pieces are compared as lists (see above).
    * Both values are plain strings: they match when, after removing
      commas, they parse as floats that differ by less than ``prec``; or
      when ``ans`` is non-empty and the two strings are identical; or when
      ``math_equal(pred, ans)`` says so.

    Args:
        item: Mapping holding at least ``pred_key`` and ``"answer"``.
        pred_key: Key under which the prediction is stored.
        prec: Absolute tolerance for the numeric comparison.

    Returns:
        ``True`` on a match; otherwise ``False`` or whatever
        ``math_equal`` returned for the string pair.

    Raises:
        NotImplementedError: If prediction and answer are neither both
            lists nor both strings. The item is printed first.
    """
    pred = item[pred_key]
    ans = item["answer"]

    if isinstance(pred, list) and isinstance(ans, list):
        pred_matched = set()
        ans_matched = set()
        for i, single_pred in enumerate(pred):
            for j, single_ans in enumerate(ans):
                # A shallow copy is enough: only the two top-level keys are
                # replaced and nothing below mutates the other values.
                pair = {**item, pred_key: single_pred, "answer": single_ans}
                if is_correct(pair, pred_key=pred_key, prec=prec):
                    pred_matched.add(i)
                    ans_matched.add(j)
        return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)

    if isinstance(pred, str) and isinstance(ans, str):
        if "\\cup" in pred and "\\cup" in ans:
            # Compare unions piecewise through the list branch.
            pieces = {
                **item,
                pred_key: pred.split("\\cup"),
                "answer": ans.split("\\cup"),
            }
            return is_correct(pieces, pred_key=pred_key, prec=prec)

        try:
            # Commas are dropped so that "1,000" parses like "1000".
            numerically_close = (
                abs(float(pred.replace(",", "")) - float(ans.replace(",", "")))
                < prec
            )
        except ValueError:
            # At least one side is not a plain number; fall through to the
            # textual / symbolic comparisons.
            numerically_close = False

        if numerically_close or (ans and pred == ans):
            return True
        return math_equal(pred, ans)

    print(item, flush=True)
    raise NotImplementedError(
        f"cannot compare prediction of type {type(pred).__name__} "
        f"with answer of type {type(ans).__name__}"
    )
|
|
|
|
|
def eval_math(item, pred_key="prediction", prec=1e-3):
    """Score a MATH-style item after de-duplicating predictions and answers.

    A string stored under ``"program_output"`` is first wrapped in a
    one-element list. When prediction and answer are both lists, repeated
    entries are removed from each (first occurrence kept, order preserved)
    and only the trailing ``len(answers)`` predictions are retained. If the
    de-duplicated answer list is empty, that slice keeps every prediction.

    The normalised prediction and answer are written back into ``item``
    before the result of ``is_correct`` is returned.
    """

    def _unique_in_order(values):
        # Membership is tested on a list, so unhashable entries are fine.
        kept = []
        for value in values:
            if value not in kept:
                kept.append(value)
        return kept

    pred = item[pred_key]
    if pred_key == "program_output" and isinstance(pred, str):
        pred = [pred]
    ans = item["answer"]

    both_lists = isinstance(pred, list) and isinstance(ans, list)
    if both_lists:
        ans = _unique_in_order(ans)
        pred = _unique_in_order(pred)[-len(ans) :]

    item.update({pred_key: pred, "answer": ans})
    return is_correct(item, pred_key=pred_key, prec=prec)
|
|
|
|
|
def eval_last_single_answer(item, pred_key="prediction", prec=1e-3):
    """Score an item whose prediction and answer are both single strings.

    Asserts that ``item[pred_key]`` and ``item["answer"]`` are ``str`` and
    then returns the result of ``is_correct`` for the item.
    """
    for key in (pred_key, "answer"):
        value = item[key]
        assert isinstance(value, str), f"{key} = `{value}` is not a str"
    return is_correct(item, pred_key=pred_key, prec=prec)
|
|
|
|
|
def _split_cloze_prediction(text):
    """Split *text* on ``;`` and on ``,`` found outside any bracket.

    Bracket depth is tracked across ``([{`` / ``)]}`` and reset after every
    split; each piece is whitespace-stripped. A ``;`` terminator is appended
    first, so the trailing piece is always emitted (possibly as ``""``).
    """
    terminated = text + ";"
    pieces = []
    depth = 0
    start = 0
    for pos, char in enumerate(terminated):
        if char == ";" or (char == "," and depth == 0):
            pieces.append(terminated[start:pos].strip())
            start = pos + 1
            depth = 0
        elif char in "([{":
            depth += 1
        elif char in ")]}":
            depth -= 1
    return pieces


def eval_agieval_gaokao_math_cloze(item, pred_key="prediction", prec=1e-3):
    """Score a cloze item that may contain several blanks.

    Every prediction string is split into individual answers on ``;`` and
    on top-level ``,``; duplicates are dropped (first occurrence kept). The
    last ``len(answers)`` pieces are then compared position by position with
    the reference answers through ``is_correct``.

    A string stored under ``"program_output"`` is wrapped in a one-element
    list in place. Apart from that, ``item`` is left untouched: each
    comparison runs on a shallow copy so the caller's list-valued prediction
    and answer are not overwritten with single strings.

    Args:
        item: Mapping with list-valued ``pred_key`` and ``"answer"``.
        pred_key: Key under which the prediction list is stored.
        prec: Numeric tolerance forwarded to ``is_correct``.

    Returns:
        ``True`` when the number of retained pieces equals the number of
        answers and every pair is judged correct, ``False`` otherwise.

    Raises:
        AssertionError: If prediction or answer is not a list.
    """
    if pred_key == "program_output" and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    for key in (pred_key, "answer"):
        assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"

    ans = item["answer"]
    candidates = []
    for raw in item[pred_key]:
        for piece in _split_cloze_prediction(raw):
            if piece not in candidates:
                candidates.append(piece)

    # Keep only the trailing pieces, one per expected blank.
    pred = candidates[-len(ans) :]
    if len(pred) != len(ans):
        return False

    return all(
        is_correct({**item, pred_key: p, "answer": a}, pred_key=pred_key, prec=prec)
        for p, a in zip(pred, ans)
    )
|
|
|
|
|
def eval_agieval_gaokao_mathqa(item, pred_key="prediction", prec=1e-3):
    """Score a multiple-choice item by the option letter found in the output.

    The prediction entries are joined with spaces. Among the letters
    ``A``-``D`` that occur in that text, the one whose *first* occurrence
    lies furthest to the right is taken as the chosen option and compared
    with ``item["answer"]``. ``prec`` is unused.

    A string stored under ``"program_output"`` is wrapped in a one-element
    list in place before joining.
    """
    if pred_key == "program_output" and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]

    text = " ".join(item[pred_key])
    expected = item["answer"]

    chosen = None
    best_pos = -1
    for letter in "ABCD":
        # find() yields -1 for a missing letter, which can never beat best_pos.
        pos = text.find(letter)
        if pos > best_pos:
            chosen, best_pos = letter, pos
    return chosen == expected
|
|
|
|
|
def eval_math_sat(item, pred_key="prediction", prec=1e-3):
    """Compare prediction and answer as case-insensitive strings.

    Asserts that ``item[pred_key]`` and ``item["answer"]`` are ``str`` and
    returns whether their lower-cased forms are equal. ``prec`` is unused.
    """
    for key in (pred_key, "answer"):
        value = item[key]
        assert isinstance(value, str), f"{key} = `{value}` is not a str"

    predicted = item[pred_key].lower()
    expected = item["answer"].lower()
    return predicted == expected
|
|
|
|
|
def eval_mmlu_stem(item, pred_key="prediction", prec=1e-3):
    """Score an MMLU-STEM item; delegates unchanged to ``eval_math_sat``."""
    verdict = eval_math_sat(item, pred_key=pred_key, prec=prec)
    return verdict
|
|
|
|
|
def eval_ocwcourses(item, pred_key="prediction", prec=1e-3):
    """Score an OCW-courses item, returning ``1`` for correct and ``0`` otherwise.

    The reference answer selects the comparison mode:

    * parses as a float -> ``normalize_numeric`` / ``numeric_equality``;
    * otherwise contains ``=`` -> ``normalize_symbolic_equation`` and plain
      ``==`` on the normalised forms;
    * otherwise -> ``SymbolicMathMixin().normalize_tex`` /
      ``SymbolicMathMixin().is_tex_equiv``.

    An empty prediction is replaced by the ``"[invalidanswer]"`` marker. The
    score is ``0`` when either the raw or the normalised prediction equals
    that marker. ``prec`` is unused.

    Raises:
        AssertionError: If prediction or answer is not a ``str``.
    """
    INVALID_ANSWER = "[invalidanswer]"

    for key in (pred_key, "answer"):
        assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
    pred = item[pred_key]
    ans = item["answer"]

    def _same(lhs, rhs):
        return lhs == rhs

    try:
        float(ans)
    except ValueError:
        if "=" in ans:
            normalize_fn = normalize_symbolic_equation
            is_equiv = _same
        else:
            normalize_fn = SymbolicMathMixin().normalize_tex
            is_equiv = SymbolicMathMixin().is_tex_equiv
    else:
        normalize_fn = normalize_numeric
        is_equiv = numeric_equality

    correct_answer = normalize_fn(ans)

    unnormalized_answer = pred if pred else INVALID_ANSWER
    # Normalisation runs even for the invalid marker, as the result is
    # checked against the marker below.
    model_answer = normalize_fn(unnormalized_answer)

    if unnormalized_answer == INVALID_ANSWER or model_answer == INVALID_ANSWER:
        return 0
    return 1 if is_equiv(model_answer, correct_answer) else 0
|
|
|
|
|
def eval_minif2f_isabelle(item, pred_key="prediction", prec=1e-3):
    """Return ``True`` unconditionally; all arguments are ignored.

    NOTE(review): presumably the actual proof checking for miniF2F-Isabelle
    happens elsewhere and this is only a placeholder -- confirm.
    """
    return True
|
|