# ScientificArgumentRecommender / src/metrics/f1_with_bootstrapping.py
# Source: https://github.com/ArneBinder/argumentation-structure-identification/pull/529 (commit d868d2e)
from collections import defaultdict
from functools import partial
from typing import Callable, Hashable, Optional, Tuple, Dict, Collection, List, Set
from pie_modules.metrics import F1Metric
from pytorch_ie import Annotation, Document
def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
    """Return True iff the value of *ann*'s *label_field* attribute is contained in *labels*."""
    label_value = getattr(ann, label_field)
    return label_value in labels
def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
    """Return True iff *ann*'s *label_field* attribute equals *label*."""
    return label == getattr(ann, label_field)
class F1WithBootstrappingMetric(F1Metric):
    """F1 metric variant that keeps the concrete TP/FP/FN annotation *sets* per label.

    Unlike a purely count-based F1 metric, this accumulates the matched,
    spurious, and missed annotation instances themselves, which is the input
    a downstream bootstrap resampling step needs (hence the class name).

    Args:
        bootstrap_n: Number of bootstrap resamples. NOTE(review): only stored
            on the instance in the visible code — presumably consumed by
            external evaluation code; confirm.
    """

    def __init__(self, *args, bootstrap_n: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        # Stored for downstream use; not read anywhere in this class.
        self.bootstrap_n = bootstrap_n

    def reset(self) -> None:
        """Discard all accumulated annotations (keyed by label; "MICRO" holds the overall sets)."""
        self.tp: Dict[str, Set[Annotation]] = defaultdict(set)
        self.fp: Dict[str, Set[Annotation]] = defaultdict(set)
        self.fn: Dict[str, Set[Annotation]] = defaultdict(set)

    def calculate_tp_fp_fn(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> Tuple[Set[Annotation], Set[Annotation], Set[Annotation]]:
        """Compare predicted vs. gold annotations of ``self.layer`` for one document.

        Args:
            document: The document to evaluate.
            annotation_filter: Keep only annotations for which this returns
                True (default: keep all).
            annotation_processor: Map each annotation to the hashable key used
                for the set comparison (default: identity).

        Returns:
            A triple ``(tp, fp, fn)`` of sets: predictions that match gold,
            predictions with no gold counterpart, and gold with no prediction.
        """
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann)
        }
        gold_annotations = {
            annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
        }
        tp = predicted_annotations & gold_annotations
        fp = predicted_annotations - gold_annotations
        fn = gold_annotations - predicted_annotations
        return tp, fp, fn

    def add_tp_fp_fn(
        self, tp: Set[Annotation], fp: Set[Annotation], fn: Set[Annotation], label: str
    ) -> None:
        """Merge the given annotation sets into the accumulators for *label*."""
        self.tp[label].update(tp)
        self.fp[label].update(fp)
        self.fn[label].update(fn)

    def _update(self, document: Document) -> None:
        """Accumulate TP/FP/FN for one document: overall ("MICRO") and, if enabled, per label."""
        new_values = self.calculate_tp_fp_fn(
            document=document,
            annotation_filter=(
                # With pre-declared labels, restrict MICRO to those labels;
                # with inferred labels we cannot know them yet, so keep all.
                partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
                if self.per_label and not self.infer_labels
                else None
            ),
            annotation_processor=self.annotation_processor,
        )
        self.add_tp_fp_fn(*new_values, label="MICRO")
        if self.infer_labels:
            layer = document[self.layer]
            # collect labels from gold data and predictions
            for ann in list(layer) + list(layer.predictions):
                label = getattr(ann, self.label_field)
                if label not in self.labels:
                    self.labels.append(label)
        if self.per_label:
            for label in self.labels:
                new_values = self.calculate_tp_fp_fn(
                    document=document,
                    annotation_filter=partial(
                        has_this_label, label_field=self.label_field, label=label
                    ),
                    annotation_processor=self.annotation_processor,
                )
                self.add_tp_fp_fn(*new_values, label=label)

    def _compute(self) -> Dict[str, Dict[str, float]]:
        """Compute precision/recall/F1 and support per accumulated label.

        Returns:
            Mapping from label (always incl. "MICRO"; plus "MACRO" when
            ``per_label``) to a dict with keys ``"f1"``, ``"p"``, ``"r"`` and
            — except for "MACRO" — the support ``"s"`` (= tp + fn).
        """
        res = dict()
        if self.per_label:
            res["MACRO"] = {"f1": 0.0, "p": 0.0, "r": 0.0}
        num_labels = len(self.labels)
        for label in self.tp.keys():
            tp = len(self.tp[label])
            fp = len(self.fp[label])
            fn = len(self.fn[label])
            if tp == 0:
                # No true positives: define precision/recall/F1 as 0 (avoids 0/0).
                p, r, f1 = 0.0, 0.0, 0.0
            else:
                p = tp / (tp + fp)
                r = tp / (tp + fn)
                f1 = 2 * p * r / (p + r)
            res[label] = {"f1": f1, "p": p, "r": r, "s": tp + fn}
            # "MICRO" is never in self.labels, so it does not enter the MACRO average.
            if self.per_label and label in self.labels:
                res["MACRO"]["f1"] += f1 / num_labels
                res["MACRO"]["p"] += p / num_labels
                res["MACRO"]["r"] += r / num_labels
        if self.show_as_markdown:
            # Fix: `logger` and `pd` were undefined in this module (NameError
            # whenever show_as_markdown was set). Import locally so pandas is
            # only required for this optional report.
            import logging
            import pandas as pd

            logging.getLogger(__name__).info(
                f"\n{self.layer}:\n{pd.DataFrame(res).round(3).T.to_markdown()}"
            )
        return res