update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529 (commit d868d2e)
import logging
from collections import defaultdict
from functools import partial
from typing import Callable, Collection, Dict, Hashable, List, Optional, Set, Tuple

import pandas as pd
from pie_modules.metrics import F1Metric
from pytorch_ie import Annotation, Document

logger = logging.getLogger(__name__)

def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
    return getattr(ann, label_field) in labels


def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
    return getattr(ann, label_field) == label
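
# The metric below uses these predicates only through functools.partial, which binds the
# label arguments to obtain single-argument annotation filters, e.g.
#   is_claim = partial(has_this_label, label_field="label", label="claim")
# ("label" and "claim" are illustrative values, not fixed by this module).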


class F1WithBootstrappingMetric(F1Metric):
    def __init__(self, *args, bootstrap_n: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        # number of bootstrap resamples
        self.bootstrap_n = bootstrap_n

    def reset(self) -> None:
        # store the actual tp/fp/fn annotations per label instead of plain counts;
        # the key "MICRO" accumulates the annotations across all labels
        self.tp: Dict[str, Set[Annotation]] = defaultdict(set)
        self.fp: Dict[str, Set[Annotation]] = defaultdict(set)
        self.fn: Dict[str, Set[Annotation]] = defaultdict(set)

    def calculate_tp_fp_fn(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> Tuple[Set[Annotation], Set[Annotation], Set[Annotation]]:
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann)
        }
        gold_annotations = {
            annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
        }
        tp = predicted_annotations & gold_annotations
        fp = predicted_annotations - gold_annotations
        fn = gold_annotations - predicted_annotations
        return tp, fp, fn

    def add_tp_fp_fn(
        self, tp: Set[Annotation], fp: Set[Annotation], fn: Set[Annotation], label: str
    ) -> None:
        self.tp[label].update(tp)
        self.fp[label].update(fp)
        self.fn[label].update(fn)

    def _update(self, document: Document) -> None:
        new_values = self.calculate_tp_fp_fn(
            document=document,
            annotation_filter=(
                partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
                if self.per_label and not self.infer_labels
                else None
            ),
            annotation_processor=self.annotation_processor,
        )
        self.add_tp_fp_fn(*new_values, label="MICRO")
        if self.infer_labels:
            layer = document[self.layer]
            # collect labels from gold data and predictions
            for ann in list(layer) + list(layer.predictions):
                label = getattr(ann, self.label_field)
                if label not in self.labels:
                    self.labels.append(label)
        if self.per_label:
            for label in self.labels:
                new_values = self.calculate_tp_fp_fn(
                    document=document,
                    annotation_filter=partial(
                        has_this_label, label_field=self.label_field, label=label
                    ),
                    annotation_processor=self.annotation_processor,
                )
                self.add_tp_fp_fn(*new_values, label=label)

    def _compute(self) -> Dict[str, Dict[str, float]]:
        res = dict()
        if self.per_label:
            res["MACRO"] = {"f1": 0.0, "p": 0.0, "r": 0.0}
        for label in self.tp.keys():
            tp, fp, fn = len(self.tp[label]), len(self.fp[label]), len(self.fn[label])
            if tp == 0:
                p, r, f1 = 0.0, 0.0, 0.0
            else:
                p = tp / (tp + fp)
                r = tp / (tp + fn)
                f1 = 2 * p * r / (p + r)
            res[label] = {"f1": f1, "p": p, "r": r, "s": tp + fn}
            # only proper labels contribute to the MACRO average; the "MICRO" entry is skipped
            if self.per_label and label in self.labels:
                res["MACRO"]["f1"] += f1 / len(self.labels)
                res["MACRO"]["p"] += p / len(self.labels)
                res["MACRO"]["r"] += r / len(self.labels)
        if self.show_as_markdown:
            logger.info(f"\n{self.layer}:\n{pd.DataFrame(res).round(3).T.to_markdown()}")
        return res
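

# Minimal usage sketch. It assumes that the base F1Metric accepts `layer`, `labels` and
# `label_field` keyword arguments matching the attributes used above, and it uses
# pytorch_ie's TextDocumentWithLabeledSpans / LabeledSpan with its "labeled_spans" layer
# purely for illustration.
if __name__ == "__main__":
    from pytorch_ie.annotations import LabeledSpan
    from pytorch_ie.documents import TextDocumentWithLabeledSpans

    doc = TextDocumentWithLabeledSpans(text="It should be banned because it is harmful.")
    # gold annotations go into the layer itself ...
    doc.labeled_spans.append(LabeledSpan(start=0, end=19, label="claim"))
    doc.labeled_spans.append(LabeledSpan(start=28, end=41, label="premise"))
    # ... predictions into its `predictions` sub-layer; here the second span gets the
    # wrong label, so it counts as one false positive and one false negative
    doc.labeled_spans.predictions.append(LabeledSpan(start=0, end=19, label="claim"))
    doc.labeled_spans.predictions.append(LabeledSpan(start=28, end=41, label="claim"))

    metric = F1WithBootstrappingMetric(
        layer="labeled_spans", labels=["claim", "premise"], label_field="label"
    )
    metric.reset()
    metric._update(doc)
    # per-label scores plus "MICRO" (and "MACRO" when per_label is set)
    print(metric._compute())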