from typing import Callable, Hashable, Optional, Tuple

from pie_modules.metrics import F1Metric
from pytorch_ie import Annotation, Document


class F1WithThresholdMetric(F1Metric):
    """F1 metric that ignores low-confidence predictions.

    Works like ``F1Metric``, except that predicted annotations whose ``score``
    attribute is below ``threshold`` are discarded before true positives,
    false positives and false negatives are counted. Gold annotations are
    never filtered by score: the threshold describes prediction confidence,
    not reference data.
    """

    def __init__(self, *args, threshold: float = 0.0, **kwargs):
        """Create the metric.

        Args:
            *args: Positional arguments forwarded unchanged to ``F1Metric``.
            threshold: Minimum ``score`` a predicted annotation must have to be
                counted. Predictions without a ``score`` attribute are treated
                as having a score of 0.0.
            **kwargs: Keyword arguments forwarded unchanged to ``F1Metric``.
        """
        super().__init__(*args, **kwargs)
        self.threshold = threshold

    def calculate_counts(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> Tuple[int, int, int]:
        """Count matches between predicted and gold annotations of ``self.layer``.

        Args:
            document: Document whose ``self.layer`` holds the gold annotations
                and whose ``.predictions`` holds the predicted ones.
            annotation_filter: Optional predicate; annotations for which it
                returns False are ignored (applied to gold and predictions).
                Defaults to keeping everything.
            annotation_processor: Optional function mapping each annotation to
                the hashable value that is compared. Defaults to the annotation
                itself.

        Returns:
            ``(tp, fp, fn)``: the number of processed predictions also found in
            gold, predictions not in gold, and gold annotations not predicted.
            Counts are over distinct processed values (set semantics).
        """
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)

        # Only predictions carry a meaningful confidence, so the threshold is
        # applied here; a missing score counts as 0.0.
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann) and getattr(ann, "score", 0.0) >= self.threshold
        }
        # Gold annotations must not be thresholded: they may lack a score
        # attribute entirely, and dropping them would hide false negatives.
        gold_annotations = {
            annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
        }

        tp = len(predicted_annotations & gold_annotations)
        fn = len(gold_annotations - predicted_annotations)
        fp = len(predicted_annotations - gold_annotations)
        return tp, fp, fn