update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
from typing import Callable, Hashable, Optional, Tuple | |
from pie_modules.metrics import F1Metric | |
from pytorch_ie import Annotation, Document | |
class F1WithThresholdMetric(F1Metric):
    """F1 metric that only counts annotations whose ``score`` reaches a threshold.

    Behaves like :class:`F1Metric`, except that :meth:`calculate_counts` ignores
    every annotation (predicted *and* gold) for which
    ``getattr(annotation, "score", 0.0) >= threshold`` does not hold.
    """

    def __init__(self, *args, threshold: float = 0.0, **kwargs):
        """Create the metric.

        Args:
            *args: Positional arguments forwarded unchanged to :class:`F1Metric`.
            threshold: Minimum ``score`` an annotation must have to be counted.
                Annotations without a ``score`` attribute are treated as having
                score ``0.0``.
            **kwargs: Keyword arguments forwarded unchanged to :class:`F1Metric`.
        """
        super().__init__(*args, **kwargs)
        self.threshold = threshold

    def _passes_threshold(self, annotation: Annotation) -> bool:
        """Return whether ``annotation`` has a score of at least ``self.threshold``."""
        # Annotations lacking a ``score`` attribute count as 0.0, so they are
        # only kept when the threshold is <= 0.0.
        return getattr(annotation, "score", 0.0) >= self.threshold

    def calculate_counts(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> Tuple[int, int, int]:
        """Count true positives, false positives and false negatives for one document.

        Gold annotations are read from ``document[self.layer]`` and predicted
        annotations from ``document[self.layer].predictions``. Both sides are
        compared as sets of processed annotations, so annotations that map to
        the same processed value are counted once.

        Args:
            document: The document to evaluate.
            annotation_filter: Optional predicate; annotations for which it
                returns ``False`` are ignored. Defaults to keeping everything.
            annotation_processor: Optional mapping applied to each kept
                annotation before comparison; its results must be hashable.
                Defaults to the identity.

        Returns:
            A tuple ``(tp, fp, fn)``.
        """
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)

        # The user filter is evaluated first, then the threshold (short-circuit
        # order kept as in the original implementation).
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann) and self._passes_threshold(ann)
        }
        # NOTE(review): the threshold is applied to gold annotations as well.
        # A gold annotation with score < threshold (or without a ``score``
        # attribute while threshold > 0.0) is dropped, so a matching prediction
        # would be counted as a false positive. Confirm this is intended.
        gold_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer]
            if annotation_filter(ann) and self._passes_threshold(ann)
        }

        tp = len(predicted_annotations & gold_annotations)
        fn = len(gold_annotations - predicted_annotations)
        fp = len(predicted_annotations - gold_annotations)
        return tp, fp, fn