from collections import Counter
from typing import Dict, Hashable, List, Optional, Sequence, Tuple, TypeVar

import numpy as np
from pytorch_ie import Annotation, Document, DocumentMetric
from pytorch_ie.annotations import BinaryRelation

from src.utils.graph_utils import get_connected_components


class CorefHoiEvaluator(object):
    """Accumulates precision/recall counts for a single coreference metric."""

    def __init__(self, metric, beta=1):
        self.p_num = 0
        self.p_den = 0
        self.r_num = 0
        self.r_den = 0
        self.metric = metric
        self.beta = beta

    def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
        if self.metric == ceafe_simplified:
            pn, pd, rn, rd = self.metric(predicted, gold)
        else:
            pn, pd = self.metric(predicted, mention_to_gold)
            rn, rd = self.metric(gold, mention_to_predicted)
        self.p_num += pn
        self.p_den += pd
        self.r_num += rn
        self.r_den += rd

    def f1(self, p_num, p_den, r_num, r_den, beta=1):
        p = 0 if p_den == 0 else p_num / float(p_den)
        r = 0 if r_den == 0 else r_num / float(r_den)
        return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r)

    def get_f1(self):
        return self.f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)

    def get_recall(self):
        return 0 if self.r_num == 0 else self.r_num / float(self.r_den)

    def get_precision(self):
        return 0 if self.p_num == 0 else self.p_num / float(self.p_den)

    def get_prf(self):
        return self.get_precision(), self.get_recall(), self.get_f1()

    def get_counts(self):
        return self.p_num, self.p_den, self.r_num, self.r_den


def b_cubed_simplified(clusters, mention_to_gold):
    """B^3: mention-level overlap between system and gold clusters, ignoring singletons."""
    num, dem = 0, 0
    for c in clusters:
        if len(c) == 1:
            continue
        gold_counts = Counter()
        correct = 0
        for m in c:
            if m in mention_to_gold:
                gold_counts[tuple(mention_to_gold[m])] += 1
        for c2, count in gold_counts.items():
            if len(c2) != 1:
                correct += count * count
        num += correct / float(len(c))
        dem += len(c)
    return num, dem


def muc_simplified(clusters, mention_to_gold):
    """MUC: link-based metric counting the minimal links shared between cluster partitions."""
    tp, p = 0, 0
    for c in clusters:
        p += len(c) - 1
        tp += len(c)
        linked = set()
        for m in c:
            if m in mention_to_gold:
                linked.add(mention_to_gold[m])
            else:
                tp -= 1
        tp -= len(linked)
    return tp, p


def phi4_simplified(c1, c2):
    """Entity similarity (phi_4) used by CEAF_e: normalized mention overlap of two clusters."""
    return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))


def ceafe_simplified(clusters, gold_clusters):
    """CEAF_e: optimal one-to-one cluster alignment via linear sum assignment."""
    # lazy import to avoid forcing a scipy installation
    from scipy.optimize import linear_sum_assignment as linear_assignment

    clusters = [c for c in clusters if len(c) != 1]
    scores = np.zeros((len(gold_clusters), len(clusters)))
    for i in range(len(gold_clusters)):
        for j in range(len(clusters)):
            scores[i, j] = phi4_simplified(gold_clusters[i], clusters[j])
    matching = linear_assignment(-scores)
    matching = np.transpose(np.asarray(matching))
    similarity = sum(scores[matching[:, 0], matching[:, 1]])
    return similarity, len(clusters), similarity, len(gold_clusters)


def lea_simplified(clusters, mention_to_gold):
    """LEA: link-based entity-aware metric, weighting each cluster by its size."""
    num, dem = 0, 0
    for c in clusters:
        if len(c) == 1:
            continue
        common_links = 0
        all_links = len(c) * (len(c) - 1) / 2.0
        for i, m in enumerate(c):
            if m in mention_to_gold:
                for m2 in c[i + 1 :]:
                    if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]:
                        common_links += 1
        num += len(c) * common_links / float(all_links)
        dem += len(c)
    return num, dem


H = TypeVar("H", bound=Hashable)


class CorefHoiF1(DocumentMetric):
    """Coreference evaluation based on the official coref-hoi evaluation script, i.e.,
    https://github.com/lxucs/coref-hoi/blob/5ddfc3b64a5519c3555b5a57e47ab2f03c104a60/metrics.py.

    The metric expects documents with a relation layer that contains binary relations between
    mentions from the same coreference cluster. It works with relations targeting mentions from
    multiple layers (e.g., cross-textual relations).

    Args:
        relation_layer: The name of the relation layer that contains the link relations.
        include_singletons: If True (default), singletons will be included in the evaluation.
        link_relation_label: If provided, only relations with this label will be used to create
            the clusters.
        link_relation_relation_score_threshold: If provided, only relations with a score greater
            than or equal to this threshold will be used to create the clusters.
    """

    def __init__(
        self,
        relation_layer: str,
        include_singletons: bool = True,
        link_relation_label: Optional[str] = None,
        link_relation_relation_score_threshold: Optional[float] = None,
    ) -> None:
        super().__init__()
        self.relation_layer = relation_layer
        self.link_relation_label = link_relation_label
        self.include_singletons = include_singletons
        self.link_relation_relation_score_threshold = link_relation_relation_score_threshold

    def reset(self) -> None:
        self.evaluators = [
            CorefHoiEvaluator(m)
            for m in (muc_simplified, b_cubed_simplified, ceafe_simplified)
        ]

    def prepare_clusters_with_mapping(
        self, mentions: Sequence[Annotation], relations: Sequence[BinaryRelation]
    ) -> Tuple[List[List[Annotation]], Dict[Annotation, Tuple[Annotation]]]:
        # get connected components based on binary relations
        connected_components = get_connected_components(
            elements=mentions,
            relations=relations,
            link_relation_label=self.link_relation_label,
            link_relation_relation_score_threshold=self.link_relation_relation_score_threshold,
            add_singletons=self.include_singletons,
        )

        # store all clusters in a list and create a map from each mention to its cluster
        # (i.e. to the tuple of spans that includes all other mentions from the same cluster)
        clusters = []
        mention_to_cluster = dict()
        for cluster in connected_components:
            clusters.append(cluster)
            for mention in cluster:
                mention_to_cluster[mention] = tuple(cluster)
        return clusters, mention_to_cluster

    def _update(self, doc: Document) -> None:
        relation_layer = doc[self.relation_layer]
        gold_mentions = []
        predicted_mentions = []
        for mention_layer in relation_layer.target_layers.values():
            gold_mentions.extend(mention_layer)
            predicted_mentions.extend(mention_layer.predictions)

        # prepare the clusters and mention-to-cluster mapping needed for evaluation
        predicted_clusters, mention_to_predicted = self.prepare_clusters_with_mapping(
            mentions=predicted_mentions, relations=relation_layer.predictions
        )
        gold_clusters, mention_to_gold = self.prepare_clusters_with_mapping(
            mentions=gold_mentions, relations=relation_layer
        )

        for e in self.evaluators:
            e.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold)

    def get_f1(self) -> float:
        return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators)

    def get_recall(self) -> float:
        return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators)

    def get_precision(self) -> float:
        return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators)

    def get_prf(self) -> Tuple[float, float, float]:
        return self.get_precision(), self.get_recall(), self.get_f1()

    def _compute(self) -> float:
        return self.get_f1()
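

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the metric): it shows
# what the cluster and mention-to-cluster structures consumed by
# CorefHoiEvaluator look like and how the averaged MUC / B^3 / CEAF_e F1 is
# obtained, mirroring CorefHoiF1.reset(), _update() and get_f1(). The mention
# identifiers ("m1" ... "m5") are made up for demonstration; in practice the
# clusters and mappings are produced by CorefHoiF1.prepare_clusters_with_mapping()
# from annotated documents.
if __name__ == "__main__":
    # gold clusters: {m1, m2, m3} and {m4, m5}; predicted: {m1, m2} and {m3, m4, m5}
    gold_clusters = [("m1", "m2", "m3"), ("m4", "m5")]
    predicted_clusters = [("m1", "m2"), ("m3", "m4", "m5")]

    # map each mention to the (hashable) cluster tuple it belongs to
    mention_to_gold = {m: c for c in gold_clusters for m in c}
    mention_to_predicted = {m: c for c in predicted_clusters for m in c}

    evaluators = [
        CorefHoiEvaluator(m)
        for m in (muc_simplified, b_cubed_simplified, ceafe_simplified)
    ]
    for e in evaluators:
        e.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold)

    # CorefHoiF1 reports the unweighted average of the three F1 scores
    avg_f1 = sum(e.get_f1() for e in evaluators) / len(evaluators)
    print(f"average coref F1: {avg_f1:.4f}")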