import ast
import re
import unicodedata as ud


def clean_answer(answer: str):
    # remove surrounding whitespace and a trailing full stop
    clean = answer.strip().strip(".")
    # reduce multiple spaces to a single space
    clean = re.sub(r"[ ]+", " ", clean)
    # reduce to lower case
    clean = clean.lower()
    # remove internal + (can't currently handle it for marking)
    clean = re.sub(r"\+", "", clean)
    # make quotes consistent
    quotes_map = {"‘": "'", "’": "'", "“": '"', "”": '"'}
    for k, v in quotes_map.items():
        clean = re.sub(k, v, clean)
    # make unicode consistent
    clean = ud.normalize("NFKD", clean)
    return clean
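
# Example behaviour (illustrative, not from the original file):
#   clean_answer('  The “Answer”.  ')  ->  'the "answer"'
#   clean_answer("1 + 1")              ->  '1  1'  (spaces are collapsed before "+" is stripped)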


def safe_exact(references: list[str], predictions: list[str]):
    if len(references[0]) == 0:
        return 1.0
    if len(predictions[0]) == 0:
        return 0.0
    score = float(references[0] == predictions[0])
    return score
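
# Note (illustrative): only the first reference/prediction pair is compared;
# an empty gold string counts as correct, an empty prediction as wrong:
#   safe_exact(["paris"], ["paris"])  -> 1.0
#   safe_exact(["paris"], ["lyon"])   -> 0.0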


def parse_str_list_score(model, correct, scoring_func):
    model = str(model)
    if len(correct) == 0:
        return 1.0
    if len(model) == 0:
        return 0.0
    # the gold answer may be a stringified list, e.g. "['a', 'b']";
    # guard literal_eval against strings that merely look list-like
    if ("[" in correct) and (("'" in correct) or ('"' in correct)):
        try:
            readstr = ast.literal_eval(correct)
        except (SyntaxError, ValueError):
            readstr = correct
        if isinstance(readstr, list):
            correct = readstr
    if isinstance(correct, list):
        if all(isinstance(c, str) for c in correct):
            # bleu and rouge are expensive and don't make sense for any-order
            # problems, so fall back to exact membership for long lists
            if len(correct) > 24:
                return clean_answer(model) in [clean_answer(c) for c in correct]
            max_score = 0.0
            for c in correct:
                score = scoring_func(
                    references=[clean_answer(c)],
                    predictions=[clean_answer(model)],
                )
                max_score = max(max_score, score)
            return max_score
        else:
            # nested lists are flattened into comma-separated strings
            max_score = 0.0
            for c in correct:
                if isinstance(c, list):
                    c = ", ".join(c)
                score = scoring_func(
                    references=[clean_answer(c)],
                    predictions=[clean_answer(model)],
                )
                max_score = max(max_score, score)
            return max_score
    else:
        return scoring_func(
            references=[clean_answer(correct)],
            predictions=[clean_answer(model)],
        )
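
# Example (illustrative): a gold answer given as a stringified list accepts
# any one of its members under the supplied scoring function:
#   parse_str_list_score("Paris", "['Paris', 'paris']", safe_exact)  -> 1.0
#   parse_str_list_score("Lyon",  "['Paris', 'paris']", safe_exact)  -> 0.0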


def exact_match(references: list[str], predictions: list[str]):
    ref_dict = ast.literal_eval(references[0])
    try:
        # try to parse the prediction as a dict literal, completing a
        # missing closing brace if necessary
        assert "{" in predictions[0]
        if predictions[0][-1] == "}":
            pred_dict = ast.literal_eval(predictions[0][predictions[0].index("{") :])
        else:
            pred_dict = ast.literal_eval(
                predictions[0][predictions[0].index("{") :] + "}"
            )
    except (SyntaxError, ValueError, AssertionError):
        # fall back to scraping a value for each gold key with
        # quote-delimited regexes
        pred_dict = {}
        for k in ref_dict.keys():
            m = re.search(re.escape(str(k)) + """': ([^']+)'[,\\}]""", predictions[0])
            n = re.search(re.escape(str(k)) + """": ([^"]+)"[,\\}]""", predictions[0])
            if m:
                pred_dict[k] = m.group()[:-1]
            elif n:
                pred_dict[k] = n.group()[:-1]
            else:
                pred_dict[k] = ""
    # score every gold key, treating missing keys as empty predictions
    pred_dict_full = {
        k: pred_dict[k] if k in pred_dict else "" for k in ref_dict.keys()
    }
    scores = [
        parse_str_list_score(pred_dict_full[k], v, safe_exact)
        for k, v in ref_dict.items()
    ]
    return scores
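
# Example (illustrative): references carry a stringified dict of gold fields;
# the result is one score per field, in gold-key order:
#   exact_match(["{'city': 'Paris', 'country': 'France'}"],
#               ["{'city': 'Paris', 'country': 'Spain'}"])  ->  [1.0, 0.0]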


def aggregate_scores(input):
    # micro-average: total per-field score over total number of fields
    return sum([sum(i) for i in input]) / sum([len(j) for j in input])
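
# Example (illustrative): aggregate_scores([[1.0, 0.0], [1.0, 1.0, 1.0]]) -> 0.8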


def aggregate_metrics(
    metrics_scores: list[float], dataset_size: list[int], weight_by_size: bool
):
    # the reported metric is the gap between the first two aggregated scores;
    # dataset_size and weight_by_size are currently unused
    return metrics_scores[0] - metrics_scores[1]
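

if __name__ == "__main__":
    # Minimal smoke test (illustrative; not part of the original module).
    assert clean_answer("  The “Answer”.  ") == 'the "answer"'
    assert safe_exact(["paris"], ["paris"]) == 1.0
    assert parse_str_list_score("Paris", "['Paris', 'Lyon']", safe_exact) == 1.0
    assert exact_match(["{'city': 'Paris'}"], ["{'city': 'Paris'}"]) == [1.0]
    assert aggregate_scores([[1.0, 0.0], [1.0, 1.0, 1.0]]) == 0.8
    assert aggregate_metrics([1.0, 0.5], [10, 10], weight_by_size=False) == 0.5
    print("all smoke checks passed")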