import logging
from collections import defaultdict
from functools import partial
from typing import (
Any,
Callable,
Collection,
Dict,
Hashable,
List,
Optional,
Tuple,
TypeAlias,
Union,
)

from pytorch_ie.core import Annotation, Document, DocumentMetric
from pytorch_ie.utils.hydra import resolve_target

from src.document.types import RelatedRelation

logger = logging.getLogger(__name__)


def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
    return getattr(ann, label_field) in labels


def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
    return getattr(ann, label_field) == label


InstanceType: TypeAlias = Tuple[Document, Annotation]
InstancesType: TypeAlias = Tuple[List[InstanceType], List[InstanceType], List[InstanceType]]


class TPFFPFNMetric(DocumentMetric):
"""Computes the lists of True Positive, False Positive, and False Negative
annotations for a given layer. If labels are provided, it also computes
the counts for each label separately.
Works only with `RelatedRelation` annotations for now.
Args:
layer: The layer to compute the metrics for.
labels: If provided, calculate metrics for each label.
label_field: The field to use for the label. Defaults to "label".
"""
def __init__(
self,
layer: str,
labels: Optional[Union[Collection[str], str]] = None,
label_field: str = "label",
annotation_processor: Optional[Union[Callable[[Annotation], Hashable], str]] = None,
):
super().__init__()
self.layer = layer
self.label_field = label_field
self.annotation_processor: Optional[Callable[[Annotation], Hashable]]
if isinstance(annotation_processor, str):
self.annotation_processor = resolve_target(annotation_processor)
else:
self.annotation_processor = annotation_processor
self.per_label = labels is not None
self.infer_labels = False
if self.per_label:
if isinstance(labels, str):
if labels != "INFERRED":
raise ValueError(
"labels can only be 'INFERRED' if per_label is True and labels is a string"
)
self.labels = []
self.infer_labels = True
elif isinstance(labels, Collection):
if not all(isinstance(label, str) for label in labels):
raise ValueError("labels must be a collection of strings")
if "MICRO" in labels or "MACRO" in labels:
raise ValueError(
"labels cannot contain 'MICRO' or 'MACRO' because they are used to capture aggregated metrics"
)
if len(labels) == 0:
raise ValueError("labels cannot be empty")
self.labels = list(labels)
else:
raise ValueError("labels must be a string or a collection of strings")

    def reset(self):
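        # maps a label (or the special key "MICRO") to a triple of (tp, fp, fn) instance lists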
self.tp_fp_fn = defaultdict(lambda: (list(), list(), list()))

    def get_tp_fp_fn(
self,
document: Document,
annotation_filter: Optional[Callable[[Annotation], bool]] = None,
annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
) -> InstancesType:
annotation_processor = annotation_processor or (lambda ann: ann)
annotation_filter = annotation_filter or (lambda ann: True)
predicted_annotations = {
annotation_processor(ann)
for ann in document[self.layer].predictions
if annotation_filter(ann)
}
gold_annotations = {
annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
}
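        # set arithmetic over the (processed) annotations: the intersection gives the true
        # positives, gold-only items the false negatives, prediction-only items the false positives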
tp = [(document, ann) for ann in predicted_annotations & gold_annotations]
fn = [(document, ann) for ann in gold_annotations - predicted_annotations]
fp = [(document, ann) for ann in predicted_annotations - gold_annotations]
return tp, fp, fn

    def add_annotations(self, annotations: InstancesType, label: str):
self.tp_fp_fn[label] = (
self.tp_fp_fn[label][0] + annotations[0],
self.tp_fp_fn[label][1] + annotations[1],
self.tp_fp_fn[label][2] + annotations[2],
)

    def _update(self, document: Document):
new_tp_fp_fn = self.get_tp_fp_fn(
document=document,
annotation_filter=(
partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
if self.per_label and not self.infer_labels
else None
),
annotation_processor=self.annotation_processor,
)
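        # aggregate over all annotations (restricted to the given labels, if any),
        # stored under the special key "MICRO"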
self.add_annotations(new_tp_fp_fn, label="MICRO")
if self.infer_labels:
layer = document[self.layer]
# collect labels from gold data and predictions
for ann in list(layer) + list(layer.predictions):
label = getattr(ann, self.label_field)
if label not in self.labels:
self.labels.append(label)
if self.per_label:
for label in self.labels:
new_tp_fp_fn = self.get_tp_fp_fn(
document=document,
annotation_filter=partial(
has_this_label, label_field=self.label_field, label=label
),
annotation_processor=self.annotation_processor,
)
self.add_annotations(new_tp_fp_fn, label=label)

    def format_texts(self, texts: List[str]) -> str:
return "<SEP>".join(texts)

    def format_annotation(self, ann: Annotation) -> Dict[str, Any]:
if isinstance(ann, RelatedRelation):
head_resolved = ann.head.resolve()
tail_resolved = ann.tail.resolve()
ref_resolved = ann.reference_span.resolve()
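            # each span is assumed to resolve to a (label, texts) pair, where texts is a
            # sequence of strings (e.g. for multi-spans)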
return {
"related_label": ann.label,
"related_score": round(ann.score, 3),
"query_label": head_resolved[0],
"query_texts": self.format_texts(head_resolved[1]),
"query_score": round(ann.head.score, 3),
"ref_label": ref_resolved[0],
"ref_texts": self.format_texts(ref_resolved[1]),
"ref_score": round(ann.reference_span.score, 3),
"rec_label": tail_resolved[0],
"rec_texts": self.format_texts(tail_resolved[1]),
"rec_score": round(ann.tail.score, 3),
}
else:
            raise NotImplementedError(
                f"format_annotation() is not implemented for annotation type: {type(ann).__name__}"
            )
# return ann.resolve()

    def format_instance(self, instance: InstanceType) -> Dict[str, Any]:
document, annotation = instance
result = self.format_annotation(annotation)
if getattr(document, "id", None) is not None:
result["document_id"] = document.id
return result

    def _compute(self) -> Dict[str, Dict[str, list]]:
res = dict()
for k, instances in self.tp_fp_fn.items():
res[k] = {
"tp": [self.format_instance(instance) for instance in instances[0]],
"fp": [self.format_instance(instance) for instance in instances[1]],
"fn": [self.format_instance(instance) for instance in instances[2]],
}
# if self.show_as_markdown:
# logger.info(f"\n{self.layer}:\n{pd.DataFrame(res).round(3).T.to_markdown()}")
return res
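

# Usage sketch (illustrative only, not part of the original metric implementation). It
# assumes an iterable of documents whose layer (here assumed to be named
# "related_relations") holds `RelatedRelation` annotations as gold annotations and as
# predictions. Only methods defined above are called; in practice, the public
# `DocumentMetric` interface of pytorch-ie would typically be used instead.
def _example_usage(documents: Collection[Document]) -> Dict[str, Dict[str, list]]:
    metric = TPFFPFNMetric(
        layer="related_relations",  # assumed layer name, adjust to your document type
        labels="INFERRED",  # collect the label set from the gold and predicted annotations
    )
    metric.reset()  # (re-)initialize the internal tp_fp_fn store
    for document in documents:
        metric._update(document)  # collect TP/FP/FN instances for this document
    # maps "MICRO" and each inferred label to {"tp": [...], "fp": [...], "fn": [...]}
    return metric._compute()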