# update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
# commit d868d2e (verified)
import logging | |
from collections import defaultdict | |
from functools import partial | |
from typing import ( | |
Any, | |
Callable, | |
Collection, | |
Dict, | |
Hashable, | |
List, | |
Optional, | |
Tuple, | |
TypeAlias, | |
Union, | |
) | |
from pytorch_ie.core import Annotation, Document, DocumentMetric | |
from pytorch_ie.utils.hydra import resolve_target | |
from src.document.types import RelatedRelation | |
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
    """Return True iff the value of ``ann.<label_field>`` is contained in ``labels``."""
    value = getattr(ann, label_field)
    return value in labels
def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
    """Return True iff the value of ``ann.<label_field>`` equals ``label``."""
    value = getattr(ann, label_field)
    return value == label
# A single evaluation instance: an annotation paired with the document it came
# from (the document is kept so results can be traced back to their source).
InstanceType: TypeAlias = Tuple[Document, Annotation]
# A (true positives, false positives, false negatives) triple of instance lists.
InstancesType: TypeAlias = Tuple[List[InstanceType], List[InstanceType], List[InstanceType]]
class TPFFPFNMetric(DocumentMetric):
    """Collects the lists of True Positive (TP), False Positive (FP), and False
    Negative (FN) annotation instances for a given layer.

    Predictions are compared against gold annotations per document. Aggregated
    results are always stored under the key "MICRO"; if labels are provided, the
    instances are additionally collected per label.

    Works only with `RelatedRelation` annotations for now (see
    `format_annotation`).

    Args:
        layer: The name of the annotation layer to compute the metric for.
        labels: If provided, also collect instances for each label separately.
            Either a collection of label strings, or the string "INFERRED", in
            which case the label set is collected on the fly from the processed
            documents (gold and predicted annotations).
        label_field: The annotation field that holds the label. Defaults to
            "label".
        annotation_processor: Optional callable (or its dotted import path as a
            string) that maps each annotation to a hashable value before the
            gold/prediction sets are compared.

    Raises:
        ValueError: If `labels` is an unsupported value (non-string elements,
            empty, contains the reserved names "MICRO"/"MACRO", or is a string
            other than "INFERRED").
    """

    def __init__(
        self,
        layer: str,
        labels: Optional[Union[Collection[str], str]] = None,
        label_field: str = "label",
        annotation_processor: Optional[Union[Callable[[Annotation], Hashable], str]] = None,
    ):
        super().__init__()
        self.layer = layer
        self.label_field = label_field
        self.annotation_processor: Optional[Callable[[Annotation], Hashable]]
        if isinstance(annotation_processor, str):
            # allow passing a dotted import path, e.g. from a config file
            self.annotation_processor = resolve_target(annotation_processor)
        else:
            self.annotation_processor = annotation_processor

        self.per_label = labels is not None
        self.infer_labels = False
        # always define self.labels so that accessing it never raises AttributeError,
        # even when per_label is False
        self.labels: List[str] = []
        if self.per_label:
            # NOTE: the str check must come first because str is itself a Collection
            if isinstance(labels, str):
                if labels != "INFERRED":
                    raise ValueError(
                        "labels can only be 'INFERRED' if per_label is True and labels is a string"
                    )
                self.infer_labels = True
            elif isinstance(labels, Collection):
                if not all(isinstance(label, str) for label in labels):
                    raise ValueError("labels must be a collection of strings")
                if "MICRO" in labels or "MACRO" in labels:
                    raise ValueError(
                        "labels cannot contain 'MICRO' or 'MACRO' because they are used to capture aggregated metrics"
                    )
                if len(labels) == 0:
                    raise ValueError("labels cannot be empty")
                self.labels = list(labels)
            else:
                raise ValueError("labels must be a string or a collection of strings")

    def reset(self):
        """Clear all collected instances.

        `tp_fp_fn` maps a label key (or "MICRO") to a (tp, fp, fn) triple of
        instance lists; missing keys start out as three empty lists.
        """
        self.tp_fp_fn = defaultdict(lambda: ([], [], []))

    def get_tp_fp_fn(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> InstancesType:
        """Compute the TP / FP / FN instances for a single document.

        Args:
            document: The document whose layer `self.layer` is evaluated.
            annotation_filter: Optional predicate; only annotations for which it
                returns True are considered. Defaults to accepting everything.
            annotation_processor: Optional mapping applied to each annotation
                before set comparison. Defaults to the identity.

        Returns:
            A (tp, fp, fn) triple of (document, processed_annotation) pairs.
        """
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)
        # set semantics: duplicates (after processing) collapse to one entry
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann)
        }
        gold_annotations = {
            annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
        }
        tp = [(document, ann) for ann in predicted_annotations & gold_annotations]
        fn = [(document, ann) for ann in gold_annotations - predicted_annotations]
        fp = [(document, ann) for ann in predicted_annotations - gold_annotations]
        return tp, fp, fn

    def add_annotations(self, annotations: InstancesType, label: str):
        """Accumulate a (tp, fp, fn) triple of instances under the given label key.

        Extends the stored lists in place instead of rebuilding the tuple via
        concatenation, which would make accumulation over many documents quadratic.
        """
        for stored, new in zip(self.tp_fp_fn[label], annotations):
            stored.extend(new)

    def _update(self, document: Document):
        """Process one document: collect instances for "MICRO" and, in per-label
        mode, for each label."""
        new_tp_fp_fn = self.get_tp_fp_fn(
            document=document,
            annotation_filter=(
                # restrict the aggregated ("MICRO") result to the known labels,
                # unless labels are inferred (then every encountered label counts)
                partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
                if self.per_label and not self.infer_labels
                else None
            ),
            annotation_processor=self.annotation_processor,
        )
        self.add_annotations(new_tp_fp_fn, label="MICRO")
        if self.infer_labels:
            layer = document[self.layer]
            # collect labels from gold data and predictions
            for ann in list(layer) + list(layer.predictions):
                label = getattr(ann, self.label_field)
                if label not in self.labels:
                    self.labels.append(label)
        if self.per_label:
            for label in self.labels:
                new_tp_fp_fn = self.get_tp_fp_fn(
                    document=document,
                    annotation_filter=partial(
                        has_this_label, label_field=self.label_field, label=label
                    ),
                    annotation_processor=self.annotation_processor,
                )
                self.add_annotations(new_tp_fp_fn, label=label)

    def format_texts(self, texts: List[str]) -> str:
        """Join multiple annotation texts into a single string."""
        return "<SEP>".join(texts)

    def format_annotation(self, ann: Annotation) -> Dict[str, Any]:
        """Convert an annotation into a flat dict suitable for reporting.

        Raises:
            NotImplementedError: For annotation types other than `RelatedRelation`.
        """
        if isinstance(ann, RelatedRelation):
            head_resolved = ann.head.resolve()
            tail_resolved = ann.tail.resolve()
            ref_resolved = ann.reference_span.resolve()
            return {
                "related_label": ann.label,
                "related_score": round(ann.score, 3),
                "query_label": head_resolved[0],
                "query_texts": self.format_texts(head_resolved[1]),
                "query_score": round(ann.head.score, 3),
                "ref_label": ref_resolved[0],
                "ref_texts": self.format_texts(ref_resolved[1]),
                "ref_score": round(ann.reference_span.score, 3),
                "rec_label": tail_resolved[0],
                "rec_texts": self.format_texts(tail_resolved[1]),
                "rec_score": round(ann.tail.score, 3),
            }
        raise NotImplementedError(
            f"format_annotation does not support annotation type: {type(ann).__name__}"
        )

    def format_instance(self, instance: InstanceType) -> Dict[str, Any]:
        """Format a (document, annotation) pair, attaching the document id if present."""
        document, annotation = instance
        result = self.format_annotation(annotation)
        if getattr(document, "id", None) is not None:
            result["document_id"] = document.id
        return result

    def _compute(self) -> Dict[str, Dict[str, list]]:
        """Return the collected instances, formatted, keyed by label and by
        "tp"/"fp"/"fn"."""
        res = dict()
        for k, instances in self.tp_fp_fn.items():
            res[k] = {
                "tp": [self.format_instance(instance) for instance in instances[0]],
                "fp": [self.format_instance(instance) for instance in instances[1]],
                "fn": [self.format_instance(instance) for instance in instances[2]],
            }
        return res