import logging
from collections import defaultdict
from functools import partial
from typing import Callable, Collection, Dict, Hashable, List, Optional, Set, Tuple

from pie_modules.metrics import F1Metric
from pytorch_ie import Annotation, Document

logger = logging.getLogger(__name__)


def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
    """Return True if the value of attribute ``label_field`` of ``ann`` is in ``labels``."""
    return getattr(ann, label_field) in labels


def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
    """Return True if the value of attribute ``label_field`` of ``ann`` equals ``label``."""
    return getattr(ann, label_field) == label


class F1WithBootstrappingMetric(F1Metric):
    """F1 metric that keeps the individual tp / fp / fn entries instead of plain counts.

    For the aggregate key ``"MICRO"`` and, if ``per_label`` is enabled, for every
    label, the sets of (processed) annotations counted as true positives, false
    positives and false negatives are stored in ``self.tp``, ``self.fp`` and
    ``self.fn``. Entries collected via ``_update`` are tagged with the index of
    the document they come from. Equal annotations from different documents are
    therefore counted separately instead of being collapsed by the set.

    Args:
        *args: Forwarded to ``F1Metric``.
        bootstrap_n: Number of bootstrap samples. NOTE(review): this value is
            stored but not used anywhere in this class; bootstrapping is not
            implemented here yet.
        **kwargs: Forwarded to ``F1Metric``.
    """

    def __init__(self, *args, bootstrap_n: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.bootstrap_n = bootstrap_n

    def reset(self) -> None:
        """Clear all collected tp / fp / fn entries and the document counter."""
        self.tp: Dict[str, Set[Hashable]] = defaultdict(set)
        self.fp: Dict[str, Set[Hashable]] = defaultdict(set)
        self.fn: Dict[str, Set[Hashable]] = defaultdict(set)
        # Running index used to tag entries with their source document, so that
        # equal annotations from different documents do not deduplicate.
        self._num_documents = 0

    def calculate_tp_fp_fn(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> Tuple[Set[Hashable], Set[Hashable], Set[Hashable]]:
        """Compare predicted and gold annotations of ``self.layer`` in one document.

        Args:
            document: The document to evaluate.
            annotation_filter: Keep only annotations for which this returns True
                (default: keep all).
            annotation_processor: Maps each annotation to the hashable value that
                is compared (default: the annotation itself).

        Returns:
            ``(tp, fp, fn)`` as sets of processed annotations.
        """
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann)
        }
        gold_annotations = {
            annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
        }
        return (
            predicted_annotations & gold_annotations,
            predicted_annotations - gold_annotations,
            gold_annotations - predicted_annotations,
        )

    def add_tp_fp_fn(
        self,
        tp: Set[Hashable],
        fp: Set[Hashable],
        fn: Set[Hashable],
        label: str,
        document_idx: Optional[int] = None,
    ) -> None:
        """Add tp / fp / fn entries to the collections for ``label``.

        Args:
            tp, fp, fn: Entries to add.
            label: Key under which the entries are stored (a label or ``"MICRO"``).
            document_idx: If given, every entry is stored as ``(document_idx, entry)``
                so that equal entries from different documents stay distinct.
        """
        if document_idx is not None:
            tp = {(document_idx, entry) for entry in tp}
            fp = {(document_idx, entry) for entry in fp}
            fn = {(document_idx, entry) for entry in fn}
        self.tp[label].update(tp)
        self.fp[label].update(fp)
        self.fn[label].update(fn)

    def _update(self, document: Document) -> None:
        """Collect tp / fp / fn of ``document`` for ``"MICRO"`` and, if enabled, per label."""
        document_idx = self._num_documents
        self._num_documents += 1

        new_values = self.calculate_tp_fp_fn(
            document=document,
            # with a fixed label set in per-label mode, MICRO only considers those labels
            annotation_filter=(
                partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
                if self.per_label and not self.infer_labels
                else None
            ),
            annotation_processor=self.annotation_processor,
        )
        self.add_tp_fp_fn(*new_values, label="MICRO", document_idx=document_idx)

        if self.infer_labels:
            layer = document[self.layer]
            # collect labels from gold data and predictions
            for ann in list(layer) + list(layer.predictions):
                label = getattr(ann, self.label_field)
                if label not in self.labels:
                    self.labels.append(label)

        if self.per_label:
            for label in self.labels:
                new_values = self.calculate_tp_fp_fn(
                    document=document,
                    annotation_filter=partial(
                        has_this_label, label_field=self.label_field, label=label
                    ),
                    annotation_processor=self.annotation_processor,
                )
                self.add_tp_fp_fn(*new_values, label=label, document_idx=document_idx)

    def _compute(self) -> Dict[str, Dict[str, float]]:
        """Compute precision, recall, F1 and support per collected key.

        Returns:
            A mapping from key (``"MICRO"``, each label and, in per-label mode,
            ``"MACRO"``) to a dict with ``"f1"``, ``"p"`` and ``"r"``. Every key
            except ``"MACRO"`` also has a support value ``"s"`` (= tp + fn).
        """
        res: Dict[str, Dict[str, float]] = dict()
        if self.per_label:
            res["MACRO"] = {"f1": 0.0, "p": 0.0, "r": 0.0}
        for label in self.tp.keys():
            tp, fp, fn = (
                len(self.tp[label]),
                len(self.fp[label]),
                len(self.fn[label]),
            )
            if tp == 0:
                # also avoids division by zero below
                p, r, f1 = 0.0, 0.0, 0.0
            else:
                p = tp / (tp + fp)
                r = tp / (tp + fn)
                f1 = 2 * p * r / (p + r)
            res[label] = {"f1": f1, "p": p, "r": r, "s": tp + fn}
            # MACRO averages only over the configured / inferred labels (not "MICRO")
            if self.per_label and label in self.labels:
                res["MACRO"]["f1"] += f1 / len(self.labels)
                res["MACRO"]["p"] += p / len(self.labels)
                res["MACRO"]["r"] += r / len(self.labels)
        if self.show_as_markdown:
            # local import: pandas is only needed to render the markdown table
            import pandas as pd

            logger.info(f"\n{self.layer}:\n{pd.DataFrame(res).round(3).T.to_markdown()}")
        return res