import logging
from collections import defaultdict
from functools import partial
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    Hashable,
    List,
    Optional,
    Tuple,
    TypeAlias,
    Union,
)

from pytorch_ie.core import Annotation, Document, DocumentMetric
from pytorch_ie.utils.hydra import resolve_target

from src.document.types import RelatedRelation

logger = logging.getLogger(__name__)


def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
    """Return True if the value of `label_field` on `ann` is contained in `labels`."""
    return getattr(ann, label_field) in labels


def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
    """Return True if the value of `label_field` on `ann` equals `label`."""
    return getattr(ann, label_field) == label


# A single collected instance: the document together with the (processed) annotation.
InstanceType: TypeAlias = Tuple[Document, Annotation]
# The three instance lists, always in the order (tp, fp, fn).
InstancesType: TypeAlias = Tuple[List[InstanceType], List[InstanceType], List[InstanceType]]


class TPFFPFNMetric(DocumentMetric):
    """Computes the lists of True Positive, False Positive, and False Negative annotations for a
    given layer. If labels are provided, it also computes these lists for each label separately.
    The lists over all considered annotations are stored under the key "MICRO".

    Works only with `RelatedRelation` annotations for now.

    Args:
        layer: The layer to compute the metrics for.
        labels: If provided, calculate metrics for each label. Either a non-empty collection of
            label strings (must not contain "MICRO" or "MACRO"), or the string "INFERRED" to
            collect the labels from the gold and predicted annotations while processing
            documents. NOTE: inferred labels are not checked against the reserved names, so an
            annotation labelled "MICRO" would be merged into the aggregated entry.
        label_field: The field to use for the label. Defaults to "label".
        annotation_processor: Optional callable (or a string that is passed to
            `resolve_target` to obtain one) that maps each annotation to a hashable object
            before gold and predicted annotations are compared. Note that `format_annotation`
            only supports `RelatedRelation` objects, so the processed objects need to be of
            that type for the result to be computable.
    """

    def __init__(
        self,
        layer: str,
        labels: Optional[Union[Collection[str], str]] = None,
        label_field: str = "label",
        annotation_processor: Optional[Union[Callable[[Annotation], Hashable], str]] = None,
    ):
        super().__init__()
        self.layer = layer
        self.label_field = label_field

        self.annotation_processor: Optional[Callable[[Annotation], Hashable]]
        if isinstance(annotation_processor, str):
            self.annotation_processor = resolve_target(annotation_processor)
        else:
            self.annotation_processor = annotation_processor

        self.per_label = labels is not None
        self.infer_labels = False
        # Always defined, so the attribute can be read safely even without per-label evaluation.
        self.labels: List[str] = []
        if not self.per_label:
            return

        if isinstance(labels, str):
            if labels != "INFERRED":
                raise ValueError(
                    "labels can only be 'INFERRED' if per_label is True and labels is a string"
                )
            # The label list is filled lazily in _update().
            self.infer_labels = True
        elif isinstance(labels, Collection):
            if not all(isinstance(label, str) for label in labels):
                raise ValueError("labels must be a collection of strings")
            if "MICRO" in labels or "MACRO" in labels:
                raise ValueError(
                    "labels cannot contain 'MICRO' or 'MACRO' because they are used to capture aggregated metrics"
                )
            if len(labels) == 0:
                raise ValueError("labels cannot be empty")
            self.labels = list(labels)
        else:
            raise ValueError("labels must be a string or a collection of strings")

    def reset(self):
        """Drop all collected instances.

        Labels collected in "INFERRED" mode are kept.
        """
        # Maps a label (or "MICRO") to its (tp, fp, fn) instance lists.
        self.tp_fp_fn: Dict[str, InstancesType] = defaultdict(lambda: ([], [], []))

    def get_tp_fp_fn(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> InstancesType:
        """Collect the true positive, false positive and false negative instances of a document.

        Args:
            document: The document whose layer `self.layer` is evaluated.
            annotation_filter: If provided, only annotations for which it returns a truthy
                value are considered. The filter is applied to the unprocessed annotations.
            annotation_processor: If provided, it is applied to each remaining annotation and
                the comparison is done on its results.

        Returns:
            A tuple `(tp, fp, fn)` of lists with `(document, annotation)` entries. Each list is
            duplicate free and follows the order of first occurrence in the layer.
        """
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)

        layer = document[self.layer]
        # Dicts are used as insertion-ordered sets: this keeps the output order stable across
        # runs, which plain sets do not guarantee.
        predicted_annotations = {
            annotation_processor(ann): None for ann in layer.predictions if annotation_filter(ann)
        }
        gold_annotations = {
            annotation_processor(ann): None for ann in layer if annotation_filter(ann)
        }

        # True positives are always taken from the predictions. A set intersection returns the
        # elements of whichever operand is smaller, so attributes that do not take part in the
        # equality check (presumably e.g. the score - confirm against the annotation types)
        # would otherwise arbitrarily stem from either the gold or the predicted annotation.
        tp = [(document, ann) for ann in predicted_annotations if ann in gold_annotations]
        fn = [(document, ann) for ann in gold_annotations if ann not in predicted_annotations]
        fp = [(document, ann) for ann in predicted_annotations if ann not in gold_annotations]
        return tp, fp, fn

    def add_annotations(self, annotations: InstancesType, label: str):
        """Append the `(tp, fp, fn)` instance lists in `annotations` to the entry for `label`."""
        tp, fp, fn = self.tp_fp_fn[label]
        # Extend in place instead of re-creating the lists: concatenation would copy all
        # instances collected so far on every call.
        tp.extend(annotations[0])
        fp.extend(annotations[1])
        fn.extend(annotations[2])

    def _update(self, document: Document):
        """Collect the instances of `document` for "MICRO" and, if enabled, for each label."""
        # With explicitly provided labels, the aggregated entry is restricted to annotations
        # with one of these labels. Otherwise, all annotations are considered.
        micro_filter = (
            partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
            if self.per_label and not self.infer_labels
            else None
        )
        new_tp_fp_fn = self.get_tp_fp_fn(
            document=document,
            annotation_filter=micro_filter,
            annotation_processor=self.annotation_processor,
        )
        self.add_annotations(new_tp_fp_fn, label="MICRO")

        if self.infer_labels:
            layer = document[self.layer]
            # collect labels from gold data and predictions
            for ann in list(layer) + list(layer.predictions):
                label = getattr(ann, self.label_field)
                if label not in self.labels:
                    self.labels.append(label)

        if self.per_label:
            for label in self.labels:
                new_tp_fp_fn = self.get_tp_fp_fn(
                    document=document,
                    annotation_filter=partial(
                        has_this_label, label_field=self.label_field, label=label
                    ),
                    annotation_processor=self.annotation_processor,
                )
                self.add_annotations(new_tp_fp_fn, label=label)

    def format_texts(self, texts: List[str]) -> str:
        """Concatenate the given texts without any separator."""
        return "".join(texts)

    def format_annotation(self, ann: Annotation) -> Dict[str, Any]:
        """Convert an annotation into a flat dictionary.

        For a `RelatedRelation`, the result contains label, texts and score (rounded to three
        digits) of its head ("query"), reference span ("ref") and tail ("rec"), as well as label
        and score of the relation itself ("related").

        Raises:
            NotImplementedError: If `ann` is not a `RelatedRelation`.
        """
        if not isinstance(ann, RelatedRelation):
            # A generic fallback could be: return ann.resolve()
            raise NotImplementedError(
                f"formatting is only implemented for RelatedRelation annotations, "
                f"but got: {type(ann).__name__}"
            )

        # Each resolved argument is indexed as (label, texts).
        head_resolved = ann.head.resolve()
        tail_resolved = ann.tail.resolve()
        ref_resolved = ann.reference_span.resolve()
        return {
            "related_label": ann.label,
            "related_score": round(ann.score, 3),
            "query_label": head_resolved[0],
            "query_texts": self.format_texts(head_resolved[1]),
            "query_score": round(ann.head.score, 3),
            "ref_label": ref_resolved[0],
            "ref_texts": self.format_texts(ref_resolved[1]),
            "ref_score": round(ann.reference_span.score, 3),
            "rec_label": tail_resolved[0],
            "rec_texts": self.format_texts(tail_resolved[1]),
            "rec_score": round(ann.tail.score, 3),
        }

    def format_instance(self, instance: InstanceType) -> Dict[str, Any]:
        """Format the annotation of an instance and add the document id, if available."""
        document, annotation = instance
        result = self.format_annotation(annotation)
        if getattr(document, "id", None) is not None:
            result["document_id"] = document.id
        return result

    def _compute(self) -> Dict[str, Dict[str, list]]:
        """Return the formatted "tp", "fp" and "fn" instances for "MICRO" and each label."""
        res = {
            key: {
                name: [self.format_instance(instance) for instance in instances]
                for name, instances in zip(("tp", "fp", "fn"), tp_fp_fn)
            }
            for key, tp_fp_fn in self.tp_fp_fn.items()
        }
        # if self.show_as_markdown:
        #     logger.info(f"\n{self.layer}:\n{pd.DataFrame(res).round(3).T.to_markdown()}")
        return res