import ast
import re
import unicodedata as ud


def clean_answer(answer: str):
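    """Normalize an answer string before comparison: strip surrounding
    whitespace and full stops, collapse runs of spaces, lower-case, drop '+',
    map curly quotes to straight quotes, and apply NFKD normalization."""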
    # remove whitespace and final stop
    clean = answer.strip().strip(".")

    # reduce multiple spaces to a single space
    clean = re.sub(r"[ ]+", " ", clean)

    # reduce to lower case
    clean = clean.lower()

    # remove internal + (can't currently handle for marking)
    clean = clean.replace("+", "")

    # make quotes consistent
    quotes_map = {"‘": "'", "’": "'", "“": '"', "”": '"'}
    for k, v in quotes_map.items():
        clean = clean.replace(k, v)

    # make unicode consistent
    clean = ud.normalize("NFKD", clean)

    return clean


def safe_exact(references: list[str], predictions: list[str]):
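    """Exact string match on the first reference/prediction pair.

    An empty reference scores 1.0; an empty prediction scores 0.0.
    """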
    if len(references[0]) == 0:
        return 1.0
    if len(predictions[0]) == 0:
        return 0.0

    score = float(references[0] == predictions[0])

    return score


def parse_str_list_score(model, correct, scoring_func):
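    """Score a model answer against a reference that may be a plain string,
    a stringified list, or a list of accepted alternatives.

    For list-valued references, the best score across the alternatives is
    returned.
    """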
    model = str(model)
    if len(correct) == 0:
        return 1.0
    if len(model) == 0:
        return 0.0
    if ("[" in correct) and (("'" in correct) or ('"' in correct)):
        readstr = ast.literal_eval(correct)
        if isinstance(readstr, list):
            correct = readstr
    if isinstance(correct, list):
        if all(isinstance(c, str) for c in correct):
            max_score = 0.0
            if len(correct) > 24:
                # bleu and rouge are expensive and don't make sense for any-order problems
                return clean_answer(model) in [clean_answer(c) for c in correct]
            for c in correct:
                score = scoring_func(
                    references=[clean_answer(c)],
                    predictions=[clean_answer(model)],
                )
                if score > max_score:
                    max_score = score
            return max_score
        else:
            max_score = 0.0
            for c in correct:
                # flatten nested list alternatives into a single comma-joined string
                if isinstance(c, list):
                    c = ", ".join(c)
                score = scoring_func(
                    references=[clean_answer(c)],
                    predictions=[clean_answer(model)],
                )
                if score > max_score:
                    max_score = score
            return max_score
    else:
        return scoring_func(
            references=[clean_answer(correct)],
            predictions=[clean_answer(model)],
        )


def exact_match(references: list[str], predictions: list[str]):
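    """Compare a predicted dict (given as its string form) against a reference
    dict, returning one safe_exact score per reference key."""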
    ref_dict = ast.literal_eval(references[0])
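    # first try to parse the prediction as a dict literal, appending a closing
    # brace in case the model output was truncated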
    try:
        assert "{" in predictions[0]
        if predictions[0][-1] == "}":
            pred_dict = ast.literal_eval(predictions[0][predictions[0].index("{") :])
        else:
            pred_dict = ast.literal_eval(
                predictions[0][predictions[0].index("{") :] + "}"
            )
    except (SyntaxError, ValueError, AssertionError):
        pred_dict = {}
        # fall back to regex extraction: pull the quoted value that follows
        # each reference key, handling single- and double-quoted outputs
        for k in ref_dict.keys():
            m = re.search(re.escape(str(k)) + r"': ([^']+)'[,\}]", predictions[0])
            n = re.search(re.escape(str(k)) + r'": ([^"]+)"[,\}]', predictions[0])
            if m:
                pred_dict[k] = m.group(1)
            elif n:
                pred_dict[k] = n.group(1)
            else:
                pred_dict[k] = ""
    pred_dict_full = {
        k: pred_dict[k] if k in pred_dict else "" for k in ref_dict.keys()
    }

    scores = [
        parse_str_list_score(pred_dict_full[k], v, safe_exact)
        for k, v in ref_dict.items()
    ]

    return scores


def aggregate_scores(input):
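    """Micro-average: sum of all per-item scores divided by the total number of items."""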
    return sum([sum(i) for i in input]) / sum([len(j) for j in input])


def aggregate_metrics(
    metrics_scores: list[float], dataset_size: list[int], weight_by_size: bool
):
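    """Difference between the first two aggregated metric scores;
    dataset_size and weight_by_size are currently unused."""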
    return metrics_scores[0] - metrics_scores[1]
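

# Minimal usage sketch: the toy reference/prediction strings below are
# illustrative assumptions only, chosen to exercise both a plain string value
# and a list of accepted alternatives.
if __name__ == "__main__":
    reference = "{'a': 'Paris', 'b': ['cat', 'feline']}"
    prediction = "{'a': 'paris', 'b': 'dog'}"

    # one safe_exact score per reference key -> [1.0, 0.0]
    per_key_scores = exact_match([reference], [prediction])
    print(per_key_scores)

    # micro-average over all keys of all questions -> 0.5
    print(aggregate_scores([per_key_scores]))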