update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
from collections import Counter | |
from typing import Dict, Hashable, List, Optional, Sequence, Tuple, TypeVar | |
import numpy as np | |
from pytorch_ie import Annotation, Document, DocumentMetric | |
from pytorch_ie.annotations import BinaryRelation | |
from src.utils.graph_utils import get_connected_components | |
class CorefHoiEvaluator:
    """Accumulate precision/recall counts of one coreference metric over many documents.

    Part of the port of the coref-hoi evaluation script. ``metric`` is one of the
    ``*_simplified`` scoring functions of this module:

    * ``ceafe_simplified`` is called once with both cluster lists and returns all four
      counts ``(p_num, p_den, r_num, r_den)``;
    * every other metric is called twice as ``metric(clusters, mention_to_other_side)``:
      once on the predicted clusters (precision) and once on the gold clusters (recall),
      each returning ``(numerator, denominator)``.

    Args:
        metric: the scoring function whose counts are accumulated.
        beta: weight of recall relative to precision in the F-score (1 gives F1).
    """

    def __init__(self, metric, beta=1):
        # precision numerator / denominator, summed over all update() calls
        self.p_num = 0
        self.p_den = 0
        # recall numerator / denominator, summed over all update() calls
        self.r_num = 0
        self.r_den = 0
        self.metric = metric
        self.beta = beta

    @staticmethod
    def _ratio(num, den):
        """Return ``num / den`` as a float, or 0 when the denominator is zero."""
        return 0 if den == 0 else num / float(den)

    def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
        """Add the counts of one document.

        Args:
            predicted: predicted clusters (sequences of mentions).
            gold: gold clusters (sequences of mentions).
            mention_to_predicted: maps a mention to the predicted cluster containing it.
            mention_to_gold: maps a mention to the gold cluster containing it.
        """
        if self.metric is ceafe_simplified:
            # CEAF-e aligns whole cluster lists with each other and yields
            # precision and recall counts in a single call.
            pn, pd, rn, rd = self.metric(predicted, gold)
        else:
            pn, pd = self.metric(predicted, mention_to_gold)
            rn, rd = self.metric(gold, mention_to_predicted)
        self.p_num += pn
        self.p_den += pd
        self.r_num += rn
        self.r_den += rd

    def f1(self, p_num, p_den, r_num, r_den, beta=1):
        """Return the F-beta score for the given counts (0 when it is undefined)."""
        precision = self._ratio(p_num, p_den)
        recall = self._ratio(r_num, r_den)
        beta_sq = beta * beta
        denominator = beta_sq * precision + recall
        # Guarding the actual divisor covers precision == recall == 0 and also
        # beta == 0 with zero recall, which previously raised ZeroDivisionError.
        if denominator == 0:
            return 0
        return (1 + beta_sq) * precision * recall / denominator

    def get_f1(self):
        """F-beta score of the accumulated counts (beta as given to the constructor)."""
        return self.f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)

    def get_recall(self):
        """Accumulated recall; 0 when nothing has been counted on the recall side."""
        return self._ratio(self.r_num, self.r_den)

    def get_precision(self):
        """Accumulated precision; 0 when nothing has been counted on the precision side."""
        return self._ratio(self.p_num, self.p_den)

    def get_prf(self):
        """Return ``(precision, recall, f-score)``."""
        return self.get_precision(), self.get_recall(), self.get_f1()

    def get_counts(self):
        """Return the raw accumulated counts ``(p_num, p_den, r_num, r_den)``."""
        return self.p_num, self.p_den, self.r_num, self.r_den
def b_cubed_simplified(clusters, mention_to_gold):
    """Compute B-cubed counts for one side of the comparison.

    For every cluster with more than one mention, the mentions found in
    ``mention_to_gold`` are grouped by their counterpart cluster. The squared group
    sizes are summed, ignoring counterpart clusters of size one, and divided by the
    cluster size; this is added to the numerator. The cluster size is added to the
    denominator. Singleton clusters are skipped entirely.

    Returns:
        ``(numerator, denominator)``
    """
    numerator, denominator = 0, 0
    for cluster in clusters:
        size = len(cluster)
        if size == 1:
            continue
        overlap = Counter(
            tuple(mention_to_gold[mention])
            for mention in cluster
            if mention in mention_to_gold
        )
        correct = sum(
            count * count
            for counterpart, count in overlap.items()
            if len(counterpart) != 1
        )
        numerator += correct / float(size)
        denominator += size
    return numerator, denominator
def muc_simplified(clusters, mention_to_gold):
    """Compute MUC (link-based) counts for one side of the comparison.

    The denominator is the sum of ``len(cluster) - 1`` over all clusters. For the
    numerator, each cluster contributes its size minus the number of parts it is
    split into. Every mention absent from ``mention_to_gold`` is its own part, and
    every distinct counterpart cluster is one part.

    Returns:
        ``(numerator, denominator)``
    """
    true_positive, total = 0, 0
    for cluster in clusters:
        total += len(cluster) - 1
        # counterpart clusters of the mentions that have one (with repetitions)
        counterparts = [mention_to_gold[m] for m in cluster if m in mention_to_gold]
        # |cluster| - (#unmapped + #distinct counterparts) reduces to this expression
        true_positive += len(counterparts) - len(set(counterparts))
    return true_positive, total
def phi4_simplified(c1, c2):
    """Return the phi4 entity similarity ``2 * |c1 ∩ c2| / (|c1| + |c2|)``.

    The overlap counts the mentions of ``c1`` (with multiplicity) that occur in ``c2``.
    """
    shared = sum(1 for mention in c1 if mention in c2)
    return 2 * shared / float(len(c1) + len(c2))
def ceafe_simplified(clusters, gold_clusters):
    """Compute CEAF-e counts.

    Predicted singleton clusters are dropped; gold clusters are used as given. The
    remaining clusters are aligned one-to-one with the gold clusters so that the total
    phi4 similarity is maximal.

    Returns:
        ``(similarity, #predicted clusters, similarity, #gold clusters)``, i.e. the
        precision numerator/denominator followed by the recall numerator/denominator.
    """
    # imported here so that scipy is only required when CEAF-e is actually used
    from scipy.optimize import linear_sum_assignment

    predicted = [cluster for cluster in clusters if len(cluster) != 1]
    scores = np.zeros((len(gold_clusters), len(predicted)))
    for row, gold in enumerate(gold_clusters):
        for col, pred in enumerate(predicted):
            scores[row, col] = phi4_simplified(gold, pred)
    # negate: linear_sum_assignment minimises cost, we want maximal similarity
    row_ind, col_ind = linear_sum_assignment(-scores)
    similarity = sum(scores[row_ind, col_ind])
    return similarity, len(predicted), similarity, len(gold_clusters)
def lea_simplified(clusters, mention_to_gold):
    """Compute LEA (link-based entity-aware) counts for one side of the comparison.

    Each cluster with more than one mention is weighted by its size. Its resolution
    score is the fraction of its mention pairs whose two mentions map to the same
    counterpart cluster in ``mention_to_gold``. Singleton clusters are skipped.

    Returns:
        ``(numerator, denominator)``
    """
    numerator, denominator = 0, 0
    for cluster in clusters:
        size = len(cluster)
        if size == 1:
            continue
        pair_count = size * (size - 1) / 2.0
        shared_pairs = 0
        for position, first in enumerate(cluster):
            if first not in mention_to_gold:
                continue
            first_gold = mention_to_gold[first]
            shared_pairs += sum(
                1
                for second in cluster[position + 1 :]
                if second in mention_to_gold and first_gold == mention_to_gold[second]
            )
        numerator += size * shared_pairs / float(pair_count)
        denominator += size
    return numerator, denominator
# Type variable for hashable values.
# NOTE(review): not referenced anywhere in the visible part of this module -
# possibly unused; confirm before removing.
H = TypeVar("H", bound=Hashable)
class CorefHoiF1(DocumentMetric):
    """
    Coreference evaluation based on official coref-hoi evaluation script, i.e.,
    https://github.com/lxucs/coref-hoi/blob/5ddfc3b64a5519c3555b5a57e47ab2f03c104a60/metrics.py.

    The metric expects documents with a relation layer that contains binary relations
    between mentions from the same coreference cluster. Works with relations targeting
    mentions from multiple layers (e.g., cross-textual relations). Clusters are the
    connected components of the relation graph. The reported scores are the unweighted
    means over the MUC, B-cubed and CEAF-e evaluators.

    Args:
        relation_layer: The name of the relation layer that contains the link relations.
        include_singletons: If True (default), singletons will be included in the evaluation
            (passed as ``add_singletons`` to ``get_connected_components``). NOTE: with the
            scoring functions of this module, singleton clusters add nothing to the MUC and
            B-cubed counts, so this option only influences CEAF-e, which drops predicted
            singletons but keeps gold ones.
        link_relation_label: If provided, only the relations with this label will be used
            to create the clusters.
        link_relation_relation_score_threshold: If provided, only the relations with a score
            greater than or equal to this threshold will be used to create the clusters.
    """

    def __init__(
        self,
        relation_layer: str,
        include_singletons: bool = True,
        link_relation_label: Optional[str] = None,
        link_relation_relation_score_threshold: Optional[float] = None,
    ) -> None:
        super().__init__()
        self.relation_layer = relation_layer
        self.link_relation_label = link_relation_label
        self.include_singletons = include_singletons
        self.link_relation_relation_score_threshold = link_relation_relation_score_threshold

    def reset(self) -> None:
        """Create fresh accumulators for MUC, B-cubed and CEAF-e (in this order)."""
        self.evaluators = [
            CorefHoiEvaluator(m) for m in (muc_simplified, b_cubed_simplified, ceafe_simplified)
        ]

    def prepare_clusters_with_mapping(
        self, mentions: Sequence[Annotation], relations: Sequence[BinaryRelation]
    ) -> Tuple[List[List[Annotation]], Dict[Annotation, Tuple[Annotation, ...]]]:
        """Build coreference clusters and a mention-to-cluster lookup.

        Args:
            mentions: all candidate mentions of one side (gold or predicted).
            relations: the link relations of the same side; they are filtered by label and
                score threshold according to this metric's configuration.

        Returns:
            ``(clusters, mention_to_cluster)``. ``clusters`` are the connected components
            returned by ``get_connected_components``. ``mention_to_cluster`` maps every
            clustered mention to its whole cluster as a tuple.
        """
        connected_components = get_connected_components(
            elements=mentions,
            relations=relations,
            link_relation_label=self.link_relation_label,
            link_relation_relation_score_threshold=self.link_relation_relation_score_threshold,
            add_singletons=self.include_singletons,
        )
        clusters = list(connected_components)
        mention_to_cluster: Dict[Annotation, Tuple[Annotation, ...]] = {}
        for cluster in clusters:
            # Tuples are hashable, as required by the scoring functions, which use the
            # mapped clusters as Counter keys / set members. One tuple per cluster suffices.
            cluster_tuple = tuple(cluster)
            for mention in cluster:
                mention_to_cluster[mention] = cluster_tuple
        return clusters, mention_to_cluster

    def _update(self, doc: Document) -> None:
        """Add the counts of a single document to all evaluators."""
        relation_layer = doc[self.relation_layer]
        # Mentions may live in several layers (e.g. cross-textual relations), so collect
        # gold and predicted mentions from every layer targeted by the relation layer.
        gold_mentions = []
        predicted_mentions = []
        for mention_layer in relation_layer.target_layers.values():
            gold_mentions.extend(mention_layer)
            predicted_mentions.extend(mention_layer.predictions)

        predicted_clusters, mention_to_predicted = self.prepare_clusters_with_mapping(
            mentions=predicted_mentions, relations=relation_layer.predictions
        )
        gold_clusters, mention_to_gold = self.prepare_clusters_with_mapping(
            mentions=gold_mentions, relations=relation_layer
        )
        for evaluator in self.evaluators:
            evaluator.update(
                predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold
            )

    def get_f1(self) -> float:
        """Mean F1 over the MUC, B-cubed and CEAF-e evaluators."""
        return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators)

    def get_recall(self) -> float:
        """Mean recall over the MUC, B-cubed and CEAF-e evaluators."""
        return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators)

    def get_precision(self) -> float:
        """Mean precision over the MUC, B-cubed and CEAF-e evaluators."""
        return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators)

    def get_prf(self) -> Tuple[float, float, float]:
        """Return the mean ``(precision, recall, f1)``."""
        return self.get_precision(), self.get_recall(), self.get_f1()

    def _compute(self) -> float:
        """Final metric value: the mean F1 of all evaluators."""
        return self.get_f1()