File size: 4,276 Bytes
d868d2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import logging
from collections import defaultdict
from functools import partial
from typing import Callable, Collection, Dict, Hashable, List, Optional, Set, Tuple

import pandas as pd

from pie_modules.metrics import F1Metric
from pytorch_ie import Annotation, Document

def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
    """Return True iff the annotation's label (read from *label_field*) is contained in *labels*."""
    label_value = getattr(ann, label_field)
    return label_value in labels


def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
    """Return True iff the annotation's label (read from *label_field*) equals *label* exactly."""
    label_value = getattr(ann, label_field)
    return label_value == label


class F1WithBootstrappingMetric(F1Metric):
    """F1 metric variant that keeps the actual annotations behind each TP/FP/FN
    count (as sets keyed by label) instead of plain counters, so the collected
    matches can later be resampled for bootstrapping.

    NOTE(review): ``bootstrap_n`` is stored but never read in this file —
    presumably consumed by a caller or subclass that performs the resampling;
    confirm before relying on it.
    """

    def __init__(self, *args, bootstrap_n: int = 0, **kwargs):
        """Forward all arguments to :class:`F1Metric` and store *bootstrap_n*.

        Args:
            bootstrap_n: number of bootstrap resamples to draw (see class note).
        """
        super().__init__(*args, **kwargs)
        self.bootstrap_n = bootstrap_n

    def reset(self) -> None:
        """Clear the accumulated TP/FP/FN annotation sets (keyed by label)."""
        self.tp: Dict[str, Set[Annotation]] = defaultdict(set)
        self.fp: Dict[str, Set[Annotation]] = defaultdict(set)
        self.fn: Dict[str, Set[Annotation]] = defaultdict(set)

    def calculate_tp_fp_fn(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> Tuple[Set[Annotation], Set[Annotation], Set[Annotation]]:
        """Compare predicted vs. gold annotations of layer ``self.layer`` in *document*.

        Args:
            document: document whose annotation layer is evaluated.
            annotation_filter: keep only annotations for which this returns
                True (default: keep all).
            annotation_processor: map each annotation to a hashable key before
                set comparison (default: identity).

        Returns:
            A triple ``(tp, fp, fn)`` of sets: predictions that match gold,
            predictions without a gold match, and gold without a prediction.
        """
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann)
        }
        gold_annotations = {
            annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
        }
        tp = predicted_annotations & gold_annotations
        fp = predicted_annotations - gold_annotations
        fn = gold_annotations - predicted_annotations
        return tp, fp, fn

    def add_tp_fp_fn(
        self, tp: Set[Annotation], fp: Set[Annotation], fn: Set[Annotation], label: str
    ) -> None:
        """Accumulate the given TP/FP/FN sets under *label*."""
        self.tp[label].update(tp)
        self.fp[label].update(fp)
        self.fn[label].update(fn)

    def _update(self, document: Document) -> None:
        """Process one document: accumulate MICRO counts and, if configured,
        per-label counts (inferring labels from the data when requested)."""
        new_values = self.calculate_tp_fp_fn(
            document=document,
            annotation_filter=(
                partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
                if self.per_label and not self.infer_labels
                else None
            ),
            annotation_processor=self.annotation_processor,
        )
        self.add_tp_fp_fn(*new_values, label="MICRO")
        if self.infer_labels:
            layer = document[self.layer]
            # collect labels from gold data and predictions
            for ann in list(layer) + list(layer.predictions):
                label = getattr(ann, self.label_field)
                if label not in self.labels:
                    self.labels.append(label)
        if self.per_label:
            for label in self.labels:
                new_values = self.calculate_tp_fp_fn(
                    document=document,
                    annotation_filter=partial(
                        has_this_label, label_field=self.label_field, label=label
                    ),
                    annotation_processor=self.annotation_processor,
                )
                self.add_tp_fp_fn(*new_values, label=label)

    def _compute(self) -> Dict[str, Dict[str, float]]:
        """Compute precision / recall / F1 per collected label.

        Returns:
            Mapping from label (including ``"MICRO"`` and, when ``per_label``,
            ``"MACRO"``) to a dict with keys ``f1``, ``p``, ``r`` and — for
            non-MACRO entries — ``s`` (support, i.e. number of gold instances).
        """
        res = dict()
        if self.per_label:
            res["MACRO"] = {"f1": 0.0, "p": 0.0, "r": 0.0}
        for label in self.tp.keys():
            tp, fp, fn = (
                len(self.tp[label]),
                len(self.fp[label]),
                len(self.fn[label]),
            )
            if tp == 0:
                # no true positives: define p/r/f1 as 0 to avoid division by zero
                p, r, f1 = 0.0, 0.0, 0.0
            else:
                p = tp / (tp + fp)
                r = tp / (tp + fn)
                f1 = 2 * p * r / (p + r)
            res[label] = {"f1": f1, "p": p, "r": r, "s": tp + fn}
            if self.per_label and label in self.labels:
                # running mean over all known labels (MICRO is excluded by the
                # `label in self.labels` guard)
                res["MACRO"]["f1"] += f1 / len(self.labels)
                res["MACRO"]["p"] += p / len(self.labels)
                res["MACRO"]["r"] += r / len(self.labels)
        if self.show_as_markdown:
            # BUGFIX: the original referenced undefined names `logger` and `pd`
            # (neither was imported anywhere in this file), raising NameError
            # whenever show_as_markdown was enabled.
            logging.getLogger(__name__).info(
                f"\n{self.layer}:\n{pd.DataFrame(res).round(3).T.to_markdown()}"
            )
        return res