import ast
import re
import unicodedata as ud


def clean_answer(answer: str) -> str:
    """Normalize an answer for comparison: trim, collapse spaces, lowercase,
    drop '+', straighten curly quotes, and apply NFKD normalization."""
    clean = answer.strip().strip(".")
    clean = re.sub(r"[ ]+", " ", clean)
    clean = clean.lower()
    clean = clean.replace("+", "")
    # Map curly quotes (U+2018, U+2019, U+201C, U+201D) to straight ASCII quotes.
    quotes_map = {"\u2018": "'", "\u2019": "'", "\u201c": '"', "\u201d": '"'}
    for k, v in quotes_map.items():
        clean = clean.replace(k, v)
    clean = ud.normalize("NFKD", clean)
    return clean
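
# Example (illustrative, hand-traced against the function above): curly
# quotes are straightened and case, spacing, and trailing punctuation are
# normalized before comparison.
#   clean_answer("  The  ‘Answer’. ")  ->  "the 'answer'"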


def safe_exact(references: list[str], predictions: list[str]):
    # An empty reference counts as trivially satisfied; an empty prediction
    # scores zero.
    if len(references[0]) == 0:
        return 1.0
    if len(predictions[0]) == 0:
        return 0.0
    return float(references[0] == predictions[0])
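
# Example (illustrative):
#   safe_exact([""], ["anything"])    ->  1.0
#   safe_exact(["paris"], [""])       ->  0.0
#   safe_exact(["paris"], ["paris"])  ->  1.0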


def parse_str_list_score(model, correct, scoring_func):
    """Score a model answer against a reference that may be a plain string,
    a stringified list, or a list of (possibly nested) strings; returns the
    best score over all candidates."""
    model = str(model)
    if len(correct) == 0:
        return 1.0
    if len(model) == 0:
        return 0.0
    # A reference that looks like a stringified Python list is parsed back
    # into a real list so each candidate can be scored individually.
    if ("[" in correct) and (("'" in correct) or ('"' in correct)):
        readstr = ast.literal_eval(correct)
        if isinstance(readstr, list):
            correct = readstr
    if isinstance(correct, list):
        if all(isinstance(c, str) for c in correct):
            if len(correct) > 24:
                # For long candidate lists, fall back to a cheap membership test.
                return float(
                    clean_answer(model) in [clean_answer(c) for c in correct]
                )
            max_score = 0.0
            for c in correct:
                score = scoring_func(
                    references=[clean_answer(c)],
                    predictions=[clean_answer(model)],
                )
                max_score = max(max_score, score)
            return max_score
        # Mixed or nested lists: flatten inner lists to comma-separated
        # strings before scoring.
        max_score = 0.0
        for c in correct:
            if isinstance(c, list):
                c = ", ".join(c)
            score = scoring_func(
                references=[clean_answer(c)],
                predictions=[clean_answer(model)],
            )
            max_score = max(max_score, score)
        return max_score
    return scoring_func(
        references=[clean_answer(correct)],
        predictions=[clean_answer(model)],
    )
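
# Example (illustrative): a stringified list reference is scored as the
# best match over its candidates.
#   parse_str_list_score("Paris", "['Paris', 'Lyon']", safe_exact)  ->  1.0
#   parse_str_list_score("Rome", "['Paris', 'Lyon']", safe_exact)   ->  0.0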


def exact_match(references: list[str], predictions: list[str]):
    ref_dict = ast.literal_eval(references[0])
    # First try to parse the prediction as a Python dict, closing an
    # unterminated trailing brace if necessary.
    try:
        assert "{" in predictions[0]
        if predictions[0][-1] == "}":
            pred_dict = ast.literal_eval(predictions[0][predictions[0].index("{") :])
        else:
            pred_dict = ast.literal_eval(
                predictions[0][predictions[0].index("{") :] + "}"
            )
    except (SyntaxError, ValueError, AssertionError):
        pred_dict = {}
    if not isinstance(pred_dict, dict):
        pred_dict = {}
    # Regex fallback for keys the dict parse did not recover: capture the
    # text between "<key>': " (or '<key>": ') and the closing quote.
    for k in ref_dict.keys():
        if k in pred_dict:
            continue
        m = re.search(re.escape(str(k)) + """': ([^']+)'[,\\}]""", predictions[0])
        n = re.search(re.escape(str(k)) + """": ([^"]+)"[,\\}]""", predictions[0])
        if m:
            pred_dict[k] = m.group(1)
        elif n:
            pred_dict[k] = n.group(1)
        else:
            pred_dict[k] = ""
    pred_dict_full = {k: pred_dict.get(k, "") for k in ref_dict.keys()}

    scores = [
        parse_str_list_score(pred_dict_full[k], v, safe_exact)
        for k, v in ref_dict.items()
    ]

    return scores
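
# Example (illustrative, assuming the prediction parses as a dict): scores
# are returned per key, in the reference dict's iteration order.
#   exact_match(['{"a": "1", "b": "2"}'], ['{"a": "1", "b": "3"}'])  ->  [1.0, 0.0]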


def aggregate_scores(scores):
    # Micro-average: total of all per-key scores over the total key count.
    return sum(sum(s) for s in scores) / sum(len(s) for s in scores)
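
# Example (illustrative): aggregate_scores([[1.0, 0.0], [1.0]]) -> 2.0 / 3,
# so predictions with more keys carry proportionally more weight.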


def aggregate_metrics(
    metrics_scores: list[float], dataset_size: list[int], weight_by_size: bool
):
    # Difference between the first two metric scores; dataset_size and
    # weight_by_size are accepted for interface compatibility but unused.
    return metrics_scores[0] - metrics_scores[1]
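
# Example (illustrative): aggregate_metrics([0.8, 0.3], [100, 50], True) -> 0.5,
# regardless of the dataset sizes or the weight_by_size flag.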