ArneBinder's picture
update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e verified
from typing import Callable, Hashable, Optional, Tuple
from pie_modules.metrics import F1Metric
from pytorch_ie import Annotation, Document
class F1WithThresholdMetric(F1Metric):
    """F1 metric that only counts annotations whose ``score`` reaches a threshold.

    The threshold test is applied to predicted *and* gold annotations. An
    annotation without a ``score`` attribute is treated as having score 0.0,
    so it is kept only when ``threshold <= 0.0``.
    """

    def __init__(self, *args, threshold: float = 0.0, **kwargs):
        """Store ``threshold`` and forward all other arguments to the base class.

        Args:
            threshold: Minimum ``score`` (inclusive) an annotation must have to
                be counted. Defaults to 0.0.
        """
        super().__init__(*args, **kwargs)
        self.threshold = threshold

    def calculate_counts(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> Tuple[int, int, int]:
        """Count true positives, false positives and false negatives.

        Annotations are read from ``document[self.layer]`` (gold) and from its
        ``predictions`` attribute (predicted). ``self.layer`` is not set in
        this class; presumably it is provided by the base class.

        Args:
            document: Document holding the gold and predicted annotations.
            annotation_filter: Optional predicate; annotations for which it
                returns a falsy value are ignored. Defaults to accepting all.
            annotation_processor: Optional mapping applied to each kept
                annotation before comparison; its results must be hashable.
                Defaults to the identity.

        Returns:
            A ``(tp, fp, fn)`` tuple computed on the sets of processed
            annotations.
        """

        def _identity(ann):
            return ann

        def _accept_all(ann):
            return True

        process = annotation_processor if annotation_processor else _identity
        accept = annotation_filter if annotation_filter else _accept_all

        def _collect(annotations) -> set:
            # The filter and the score check run on the raw annotation; only
            # annotations that pass both are processed and stored.
            kept = set()
            for ann in annotations:
                if accept(ann) and getattr(ann, "score", 0.0) >= self.threshold:
                    kept.add(process(ann))
            return kept

        predicted = _collect(document[self.layer].predictions)
        gold = _collect(document[self.layer])

        tp = len(predicted & gold)
        fn = len(gold - predicted)
        fp = len(predicted - gold)
        return tp, fp, fn