# update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529 (commit d868d2e)
import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Sequence, Union

from pandas import MultiIndex
from pytorch_ie import Annotation, AnnotationLayer, Document, DocumentMetric
from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.core.metric import T
from pytorch_ie.utils.hydra import resolve_target

from src.hydra_callbacks.save_job_return_value import to_py_obj

logger = logging.getLogger(__name__)


class RankingMetricsSKLearn(DocumentMetric):
    """Ranking metrics for documents with binary relations.

    This metric computes ranking metrics for retrieval tasks where the relation heads are the
    queries and the relation tails are the candidates. The true and predicted scores are
    collected per head (query) and the metrics are computed over these per-head rankings. It
    is meant to be used with scikit-learn metrics such as `sklearn.metrics.ndcg_score`
    (Normalized Discounted Cumulative Gain),
    `sklearn.metrics.label_ranking_average_precision_score` (LRAP), etc., see
    https://scikit-learn.org/stable/modules/model_evaluation.html#multilabel-ranking-metrics.

    Args:
        metrics (Dict[str, Union[str, Callable]]): A dictionary mapping metric names to their
            corresponding functions. A function can be given as a string (the dotted path of
            the function, e.g. "sklearn.metrics.ndcg_score") or as a callable.
        layer (str): The name of the annotation layer containing the binary relations, e.g.
            "binary_relations" when applied to TextDocumentsWithLabeledSpansAndBinaryRelations.
        use_manual_average (Optional[List[str]]): A list of metric names for which to use
            manual averaging. For these metrics, the score is calculated for each head
            individually and the results are averaged. For all other metrics, all true and
            predicted scores are passed to the metric function at once.
        exclude_singletons (Optional[List[str]]): A list of metric names for which singletons,
            i.e. entries (heads) with only a single candidate, are excluded from the
            computation.
        label (Optional[str]): If provided, only relations with this label are used to compute
            the metrics. This is useful for filtering out relations that are not relevant for
            the task at hand (e.g. when multiple relation types share the same layer).
        score_threshold (float): Only relations with a score greater than or equal to this
            threshold are used to compute the metrics. Default is 0.0.
        default_score (float): The default score to use for relations that are missing from
            either the target or the prediction. Default is 0.0.
        use_all_spans (bool): Whether to consider all spans in the document as queries and
            candidates, or only the spans that occur in the target or prediction relations.
            Default is False.
        span_label_blacklist (Optional[List[str]]): If provided, relations whose head or tail
            label is in this list are ignored. With use_all_spans=True, this also restricts
            the queries and candidates to spans whose label is not in the blacklist.
        show_as_markdown (bool): Whether to log the results as a markdown table. Default is
            False.
        markdown_precision (int): The precision used when displaying the results as markdown.
            Default is 4.
        plot (bool): Whether to plot the results. Not implemented yet. Default is False.
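
    Example:
        A minimal usage sketch, not taken from the repository: the metric name "ndcg", the
        layer name, and the ``documents`` iterable (documents carrying gold and predicted
        binary relations) are illustrative assumptions, and it relies on the usual
        DocumentMetric interface (call the metric on documents, then aggregate via
        ``compute()``).

        >>> metric = RankingMetricsSKLearn(
        ...     metrics={"ndcg": "sklearn.metrics.ndcg_score"},
        ...     layer="binary_relations",
        ...     use_manual_average=["ndcg"],
        ... )
        >>> for document in documents:
        ...     metric(document)
        >>> scores = metric.compute()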
""" | |

    def __init__(
        self,
        metrics: Dict[str, Union[str, Callable]],
        layer: str,
        use_manual_average: Optional[List[str]] = None,
        exclude_singletons: Optional[List[str]] = None,
        label: Optional[str] = None,
        score_threshold: float = 0.0,
        default_score: float = 0.0,
        use_all_spans: bool = False,
        span_label_blacklist: Optional[List[str]] = None,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        plot: bool = False,
    ):
        self.metrics = {
            name: resolve_target(metric) if isinstance(metric, str) else metric
            for name, metric in metrics.items()
        }
        self.use_manual_average = set(use_manual_average or [])
        self.exclude_singletons = set(exclude_singletons or [])
        self.annotation_layer_name = layer
        self.annotation_label = label
        self.score_threshold = score_threshold
        self.default_score = default_score
        self.use_all_spans = use_all_spans
        self.span_label_blacklist = span_label_blacklist
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        self.plot = plot

        super().__init__()

    def reset(self) -> None:
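        """Clear the accumulated per-head target and prediction score lists."""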
        self._preds: List[List[float]] = []
        self._targets: List[List[float]] = []

    def get_head2tail2score(
        self, relations: Sequence[BinaryRelation]
    ) -> Dict[Annotation, Dict[Annotation, float]]:
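        """Collect the relation scores as a nested mapping from head to tail to score.

        Relations are skipped if they do not match the configured label, fall below the score
        threshold, or have a head or tail whose label is in the span label blacklist.
        """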
        result: Dict[Annotation, Dict[Annotation, float]] = defaultdict(dict)
        for rel in relations:
            if (
                (self.annotation_label is None or rel.label == self.annotation_label)
                and (rel.score >= self.score_threshold)
                and (
                    self.span_label_blacklist is None
                    or (
                        rel.head.label not in self.span_label_blacklist
                        and rel.tail.label not in self.span_label_blacklist
                    )
                )
            ):
                result[rel.head][rel.tail] = rel.score
        return result

    def _update(self, document: Document) -> None:
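        """Convert the document's gold and predicted relations into per-head score lists.

        For each head (query), the candidate tails are collected and the corresponding target
        and prediction scores are appended to the internal state, filling in the default score
        for relations that are missing on either side.
        """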
        annotation_layer: AnnotationLayer[BinaryRelation] = document[self.annotation_layer_name]
        target_head2tail2score = self.get_head2tail2score(annotation_layer)
        prediction_head2tail2score = self.get_head2tail2score(annotation_layer.predictions)

        all_spans = set()
        # get spans from all layers targeted by the annotation (relation) layer
        for span_layer in annotation_layer.target_layers.values():
            all_spans.update(span_layer)
        if self.span_label_blacklist is not None:
            all_spans = {span for span in all_spans if span.label not in self.span_label_blacklist}

        if self.use_all_spans:
            all_heads = all_spans
        else:
            all_heads = set(target_head2tail2score) | set(prediction_head2tail2score)

        all_targets: List[List[float]] = []
        all_predictions: List[List[float]] = []
        for head in all_heads:
            target_tail2score = target_head2tail2score.get(head, {})
            prediction_tail2score = prediction_head2tail2score.get(head, {})
            if self.use_all_spans:
                # use all spans as tails
                tails = {span for span in all_spans if span != head}
            else:
                # use only the tails that are in the target or prediction
                tails = set(target_tail2score) | set(prediction_tail2score)
            target_scores = [target_tail2score.get(t, self.default_score) for t in tails]
            prediction_scores = [prediction_tail2score.get(t, self.default_score) for t in tails]
            all_targets.append(target_scores)
            all_predictions.append(prediction_scores)

        self._targets.extend(all_targets)
        self._preds.extend(all_predictions)

    def do_plot(self):
        raise NotImplementedError()

    def _compute(self) -> T:
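        """Compute all configured metrics over the accumulated target and prediction scores.

        Metrics listed in `use_manual_average` are computed per head and then averaged; metrics
        listed in `exclude_singletons` only consider heads with more than one candidate.
        Optionally logs the results as a markdown table.
        """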
        if self.plot:
            self.do_plot()

        result = {}
        for name, metric in self.metrics.items():
            targets, preds = self._targets, self._preds
            if name in self.exclude_singletons:
                targets = [t for t in targets if len(t) > 1]
                preds = [p for p in preds if len(p) > 1]
                num_singletons = len(self._targets) - len(targets)
                logger.warning(
                    f"Excluding {num_singletons} singletons (out of {len(self._targets)} "
                    f"entries) from {name} metric calculation."
                )
            if name in self.use_manual_average:
                scores = [
                    metric(y_true=[tgts], y_score=[prds]) for tgts, prds in zip(targets, preds)
                ]
                result[name] = sum(scores) / len(scores) if len(scores) > 0 else 0.0
            else:
                result[name] = metric(y_true=targets, y_score=preds)
        result = to_py_obj(result)

        if self.show_as_markdown:
            import pandas as pd

            series = pd.Series(result)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # in fact, this is not a series anymore
                    series = series.unstack(-1)
                else:
                    series.index = series.index.get_level_values(0)
            logger.info(
                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
            )
        return result