import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Sequence, Union

from pandas import MultiIndex
from pytorch_ie import Annotation, AnnotationLayer, Document, DocumentMetric
from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.core.metric import T
from pytorch_ie.utils.hydra import resolve_target

from src.hydra_callbacks.save_job_return_value import to_py_obj

logger = logging.getLogger(__name__)


class RankingMetricsSKLearn(DocumentMetric):
    """Ranking metrics for documents with binary relations.

    This metric computes the ranking metrics for retrieval tasks, where relation heads are the
    queries and the relation tails are the candidates. For each head, one list of target scores
    and one list of predicted scores (aligned over the same candidates) is collected. The metric
    is then computed either over all collected entries at once or per head and averaged (see
    `use_manual_average`).

    It is meant to be used with Scikit-learn metrics such as `sklearn.metrics.ndcg_score`
    (Normalized Discounted Cumulative Gain), `sklearn.metrics.label_ranking_average_precision_score`
    (LRAP), etc., see
    https://scikit-learn.org/stable/modules/model_evaluation.html#multilabel-ranking-metrics.
    Each metric function is called with the keyword arguments `y_true` and `y_score`.

    Args:
        metrics (Dict[str, Union[str, Callable]]): A dictionary of metric names and their
            corresponding functions. The function can be a string (name of the function, e.g.,
            sklearn.metrics.ndcg_score) or a callable.
        layer (str): The name of the annotation layer containing the binary relations, e.g.,
            "binary_relations" when applied to TextDocumentsWithLabeledSpansAndBinaryRelations.
        use_manual_average (Optional[List[str]]): A list of metric names to use for manual
            averaging. For these metrics, the score is calculated for each head separately and
            then averaged (0.0 if there are no entries). Otherwise, all true and predicted scores
            are passed to the metric function at once.
        exclude_singletons (Optional[List[str]]): A list of metric names for which entries
            (heads) with at most one candidate are excluded from the computation.
        label (Optional[str]): If provided, only the relations with this label will be used to
            compute the metrics. This is useful for filtering out relations that are not relevant
            for the task at hand (e.g., when having multiple relation types in the same layer).
        score_threshold (float): Only relations (target and predicted) with a score greater than
            or equal to this threshold are used to compute the metrics. Default is 0.0.
        default_score (float): The default score to use for missing relations, either in the
            target or prediction. Default is 0.0.
        use_all_spans (bool): If True, all spans from the layers targeted by the relation layer
            are used as queries, and for each query all other spans are used as candidates.
            If False, only the heads and tails present in the target or prediction are used.
        span_label_blacklist (Optional[List[str]]): If provided, ignore the relations with
            heads/tails whose label is in this list. When using use_all_spans=True, this also
            restricts the queries and candidates to spans whose label is not in the blacklist.
        show_as_markdown (bool): Whether to log the results as markdown. Default is False.
        markdown_precision (int): The precision for displaying the results in markdown.
            Default is 4.
        plot (bool): Whether to call `do_plot` at the beginning of the computation. Note that
            `do_plot` is not implemented and raises NotImplementedError. Default is False.
    """

    def __init__(
        self,
        metrics: Dict[str, Union[str, Callable]],
        layer: str,
        use_manual_average: Optional[List[str]] = None,
        exclude_singletons: Optional[List[str]] = None,
        label: Optional[str] = None,
        score_threshold: float = 0.0,
        default_score: float = 0.0,
        use_all_spans: bool = False,
        span_label_blacklist: Optional[List[str]] = None,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        plot: bool = False,
    ):
        # resolve metric functions given by their (import) name
        self.metrics = {
            name: resolve_target(metric) if isinstance(metric, str) else metric
            for name, metric in metrics.items()
        }
        self.use_manual_average = set(use_manual_average or [])
        self.exclude_singletons = set(exclude_singletons or [])
        self.annotation_layer_name = layer
        self.annotation_label = label
        self.score_threshold = score_threshold
        self.default_score = default_score
        self.use_all_spans = use_all_spans
        self.span_label_blacklist = span_label_blacklist
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        self.plot = plot

        super().__init__()

    def reset(self) -> None:
        """Clear the collected entries (one list of candidate scores per query/head)."""
        self._preds: List[List[float]] = []
        self._targets: List[List[float]] = []

    def get_head2tail2score(
        self, relations: Sequence[BinaryRelation]
    ) -> Dict[Annotation, Dict[Annotation, float]]:
        """Map each head to a mapping from tail to relation score.

        Only relations that pass the label filter, the score threshold and the span label
        blacklist are included. If several relations share the same head and tail, the score of
        the last one wins.
        """
        result: Dict[Annotation, Dict[Annotation, float]] = defaultdict(dict)
        for rel in relations:
            if (
                (self.annotation_label is None or rel.label == self.annotation_label)
                and (rel.score >= self.score_threshold)
                and (
                    self.span_label_blacklist is None
                    or (
                        rel.head.label not in self.span_label_blacklist
                        and rel.tail.label not in self.span_label_blacklist
                    )
                )
            ):
                result[rel.head][rel.tail] = rel.score

        return result

    def _update(self, document: Document) -> None:
        """Collect aligned target / predicted candidate scores for every query of the document."""
        annotation_layer: AnnotationLayer[BinaryRelation] = document[self.annotation_layer_name]

        target_head2tail2score = self.get_head2tail2score(annotation_layer)
        prediction_head2tail2score = self.get_head2tail2score(annotation_layer.predictions)

        all_spans = set()
        if self.use_all_spans:
            # get spans from all layers targeted by the annotation (relation) layer; they are
            # only needed when all spans serve as queries and candidates
            for span_layer in annotation_layer.target_layers.values():
                all_spans.update(span_layer)
            if self.span_label_blacklist is not None:
                all_spans = {
                    span for span in all_spans if span.label not in self.span_label_blacklist
                }
            all_heads = all_spans
        else:
            all_heads = set(target_head2tail2score) | set(prediction_head2tail2score)

        for head in all_heads:
            target_tail2score = target_head2tail2score.get(head, {})
            prediction_tail2score = prediction_head2tail2score.get(head, {})
            if self.use_all_spans:
                # use all spans (except the head itself) as tails
                tails = [span for span in all_spans if span != head]
            else:
                # use only the tails that are in the target or prediction
                tails = list(set(target_tail2score) | set(prediction_tail2score))

            # both lists iterate over the same tails, so they stay aligned position-wise
            self._targets.append([target_tail2score.get(t, self.default_score) for t in tails])
            self._preds.append([prediction_tail2score.get(t, self.default_score) for t in tails])

    def do_plot(self):
        raise NotImplementedError()

    def _compute(self) -> T:
        """Compute all configured metrics over the collected entries.

        Returns a dictionary mapping metric names to their (python-native) values.
        """
        if self.plot:
            self.do_plot()

        result = {}
        for name, metric in self.metrics.items():
            targets, preds = self._targets, self._preds

            if name in self.exclude_singletons:
                # filter targets and predictions jointly so that they stay aligned
                kept = [(t, p) for t, p in zip(targets, preds) if len(t) > 1]
                targets = [t for t, _ in kept]
                preds = [p for _, p in kept]
                num_singletons = len(self._targets) - len(targets)
                # only warn if something was actually excluded
                if num_singletons > 0:
                    logger.warning(
                        f"Excluding {num_singletons} singletons (out of {len(self._targets)} "
                        f"entries) from {name} metric calculation."
                    )

            if name in self.use_manual_average:
                # compute the metric per entry (head) and average the results
                scores = [
                    metric(y_true=[tgts], y_score=[prds]) for tgts, prds in zip(targets, preds)
                ]
                result[name] = sum(scores) / len(scores) if len(scores) > 0 else 0.0
            else:
                result[name] = metric(y_true=targets, y_score=preds)

        result = to_py_obj(result)
        if self.show_as_markdown:
            import pandas as pd

            series = pd.Series(result)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # in fact, this is not a series anymore
                    series = series.unstack(-1)
                else:
                    series.index = series.index.get_level_values(0)
            logger.info(
                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
            )
        return result