File size: 8,925 Bytes
d868d2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Sequence, Union

from pandas import MultiIndex
from pytorch_ie import Annotation, AnnotationLayer, Document, DocumentMetric
from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.core.metric import T
from pytorch_ie.utils.hydra import resolve_target

from src.hydra_callbacks.save_job_return_value import to_py_obj

logger = logging.getLogger(__name__)


class RankingMetricsSKLearn(DocumentMetric):
    """Ranking metrics for documents with binary relations.

    This metric computes ranking metrics for retrieval tasks, where relation heads are the
    queries and relation tails are the candidates. For every head, one entry is collected:
    a row of target scores and a row of predicted scores, both aligned over the same
    candidate tails. The configured metric functions are then applied to all entries, either
    in one call or per entry with manual averaging (see `use_manual_average`). It is meant
    to be used with Scikit-learn metrics such as `sklearn.metrics.ndcg_score` (Normalized
    Discounted Cumulative Gain), `sklearn.metrics.label_ranking_average_precision_score`
    (LRAP), etc., see
    https://scikit-learn.org/stable/modules/model_evaluation.html#multilabel-ranking-metrics.

    Args:
        metrics (Dict[str, Union[str, Callable]]): A dictionary of metric names and their
            corresponding functions. The function can be a string (name of the function, e.g.,
            sklearn.metrics.ndcg_score), which is resolved via `resolve_target`, or a callable.
            Each function is called with the keyword arguments `y_true` and `y_score`.
        layer (str): The name of the annotation layer containing the binary relations, e.g.,
            "binary_relations" when applied to TextDocumentsWithLabeledSpansAndBinaryRelations.
        use_manual_average (Optional[List[str]]): A list of metric names to use for manual
            averaging. For these metrics, the metric function is called once per entry (head)
            and the unweighted mean of the results is reported (0.0 if there are no entries).
            For all other metrics, all true and predicted score rows are passed to the metric
            function at once.
        exclude_singletons (Optional[List[str]]): A list of metric names for which entries
            (heads) with at most one candidate are excluded from the computation. A warning
            with the number of excluded entries is logged for each such metric.
        label (Optional[str]): If provided, only the relations with this label will be used
            to compute the metrics. This is useful for filtering out relations that are not
            relevant for the task at hand (e.g., when having multiple relation types in the
            same layer).
        score_threshold (float): Only relations with a score greater than or equal to this
            threshold are used to compute the metrics. The filter is applied to both target
            and predicted relations. Default is 0.0.
        default_score (float): The default score to use for missing relations, either in the
            target or prediction. Default is 0.0.
        use_all_spans (bool): Whether to consider all spans in the document as queries and
            candidates or only the spans that are present in the target and prediction.
        span_label_blacklist (Optional[List[str]]): If provided, ignore the relations with
            heads/tails that are in this list. When using use_all_spans=True, this also
            restricts the candidates to those that are not in the blacklist.
        show_as_markdown (bool): Whether to show the results as markdown. Default is False.
        markdown_precision (int): The precision for displaying the results in markdown.
            Default is 4.
        plot (bool): If True, `do_plot` is called before the metrics are computed. Plotting
            is not implemented yet, so enabling this currently raises NotImplementedError.
            Default is False.
    """

    def __init__(
        self,
        metrics: Dict[str, Union[str, Callable]],
        layer: str,
        use_manual_average: Optional[List[str]] = None,
        exclude_singletons: Optional[List[str]] = None,
        label: Optional[str] = None,
        score_threshold: float = 0.0,
        default_score: float = 0.0,
        use_all_spans: bool = False,
        span_label_blacklist: Optional[List[str]] = None,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        plot: bool = False,
    ):
        # resolve string references (e.g. "sklearn.metrics.ndcg_score") to callables once
        self.metrics = {
            name: resolve_target(metric) if isinstance(metric, str) else metric
            for name, metric in metrics.items()
        }
        self.use_manual_average = set(use_manual_average or [])
        self.exclude_singletons = set(exclude_singletons or [])
        self.annotation_layer_name = layer
        self.annotation_label = label
        self.score_threshold = score_threshold
        self.default_score = default_score
        self.use_all_spans = use_all_spans
        self.span_label_blacklist = span_label_blacklist
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        self.plot = plot

        super().__init__()

    def reset(self) -> None:
        """Clear the collected target and prediction score rows."""
        # one row per entry (head); _targets[i] and _preds[i] are aligned over the same tails
        self._preds: List[List[float]] = []
        self._targets: List[List[float]] = []

    def get_head2tail2score(
        self, relations: Sequence[BinaryRelation]
    ) -> Dict[Annotation, Dict[Annotation, float]]:
        """Map each relation head to a dict of its tails and the relation scores.

        Only relations that pass the label filter, the score threshold and the span label
        blacklist (applied to both head and tail) are included. If several relations share
        the same head and tail, the score of the last one wins.

        Args:
            relations: The relations to index.

        Returns:
            A (default)dict ``head -> {tail -> score}``.
        """
        result: Dict[Annotation, Dict[Annotation, float]] = defaultdict(dict)
        for rel in relations:
            if (
                (self.annotation_label is None or rel.label == self.annotation_label)
                and (rel.score >= self.score_threshold)
                and (
                    self.span_label_blacklist is None
                    or (
                        rel.head.label not in self.span_label_blacklist
                        and rel.tail.label not in self.span_label_blacklist
                    )
                )
            ):
                result[rel.head][rel.tail] = rel.score

        return result

    def _update(self, document: Document) -> None:
        """Collect one aligned (target, prediction) score row pair per head of the document.

        Args:
            document: A document that holds the relation layer named by `layer`.
        """
        annotation_layer: AnnotationLayer[BinaryRelation] = document[self.annotation_layer_name]

        target_head2tail2score = self.get_head2tail2score(annotation_layer)
        prediction_head2tail2score = self.get_head2tail2score(annotation_layer.predictions)
        all_spans = set()
        # get spans from all layers targeted by the annotation (relation) layer
        for span_layer in annotation_layer.target_layers.values():
            all_spans.update(span_layer)

        if self.span_label_blacklist is not None:
            all_spans = {span for span in all_spans if span.label not in self.span_label_blacklist}

        if self.use_all_spans:
            all_heads = all_spans
        else:
            all_heads = set(target_head2tail2score) | set(prediction_head2tail2score)

        all_targets: List[List[float]] = []
        all_predictions: List[List[float]] = []
        for head in all_heads:
            # .get() does not trigger the defaultdict factory, so no entries are inserted
            target_tail2score = target_head2tail2score.get(head, {})
            prediction_tail2score = prediction_head2tail2score.get(head, {})
            if self.use_all_spans:
                # use all spans as tails; note this can be empty if the head is the only span,
                # which yields a zero-length entry
                tails = set(span for span in all_spans if span != head)
            else:
                # use only the tails that are in the target or prediction
                tails = set(target_tail2score) | set(prediction_tail2score)
            # both rows iterate the same (unmodified) set, so their positions are aligned
            target_scores = [target_tail2score.get(t, self.default_score) for t in tails]
            prediction_scores = [prediction_tail2score.get(t, self.default_score) for t in tails]
            all_targets.append(target_scores)
            all_predictions.append(prediction_scores)

        self._targets.extend(all_targets)
        self._preds.extend(all_predictions)

    def do_plot(self):
        """Plot the collected scores. Not implemented yet."""
        raise NotImplementedError()

    def _compute(self) -> T:
        """Compute all configured metrics over the collected entries.

        Returns:
            A dict mapping each metric name to its value, converted with `to_py_obj`.
            If `show_as_markdown` is set, the result is also logged as a markdown table.
        """
        if self.plot:
            self.do_plot()

        result = {}
        for name, metric in self.metrics.items():
            targets, preds = self._targets, self._preds
            if name in self.exclude_singletons:
                # Filter target and prediction rows as pairs so they stay aligned. Both rows of
                # an entry are built over the same tails, so checking the target row length is
                # sufficient; entries with at most one candidate are dropped.
                kept = [(tgt, prd) for tgt, prd in zip(targets, preds) if len(tgt) > 1]
                targets = [tgt for tgt, _ in kept]
                preds = [prd for _, prd in kept]
                num_singletons = len(self._targets) - len(targets)
                logger.warning(
                    f"Excluding {num_singletons} singletons (out of {len(self._targets)} "
                    f"entries) from {name} metric calculation."
                )

            if name in self.use_manual_average:
                # call the metric once per entry (as a single-sample batch) and average
                scores = [
                    metric(y_true=[tgts], y_score=[prds]) for tgts, prds in zip(targets, preds)
                ]
                result[name] = sum(scores) / len(scores) if scores else 0.0
            else:
                result[name] = metric(y_true=targets, y_score=preds)

        result = to_py_obj(result)
        if self.show_as_markdown:
            import pandas as pd

            series = pd.Series(result)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # in fact, this is not a series anymore
                    series = series.unstack(-1)
                else:
                    series.index = series.index.get_level_values(0)
            logger.info(
                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
            )
        return result