# update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529 (commit d868d2e)
import logging
import warnings
from collections import defaultdict
from functools import partial
from typing import Callable, Iterable, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
from pytorch_ie import DocumentMetric
from pytorch_ie.annotations import BinaryRelation
from sklearn.metrics import average_precision_score, ndcg_score

logger = logging.getLogger(__name__)

NEG_INF = -1e9  # smaller than any real score


# metrics
def true_mrr(y_true: np.ndarray, y_score: np.ndarray, k: int | None = None) -> float:
    """
    Macro MRR over *all* queries.
    • Reciprocal rank is 0 when a query has no relevant item.
    • If k is given, restrict the search to the top-k list.
    """
    if y_true.size == 0:
        return np.nan
    rr = []
    for t, s in zip(y_true, y_score):
        if t.sum() == 0:
            rr.append(0.0)
            continue
        order = np.argsort(-s)
        if k is not None:
            order = order[:k]
        # first position where t == 1, +1 for 1-based rank
        first_hit = np.flatnonzero(t[order] > 0)
        rank = first_hit[0] + 1 if first_hit.size else np.inf
        rr.append(0.0 if np.isinf(rank) else 1.0 / rank)
    return np.mean(rr)
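
# Worked example (illustrative sketch, not part of the original code): the first
# relevant item of query 0 sits at rank 2 and query 1 has no relevant items, so
# the macro MRR is (1/2 + 0) / 2 = 0.25.
#
#   _t = np.array([[0, 1, 0], [0, 0, 0]])
#   _s = np.array([[0.9, 0.8, 0.1], [0.5, 0.4, 0.3]])
#   true_mrr(_t, _s)  # -> 0.25
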
def macro_ndcg(y_true: np.ndarray, y_score: np.ndarray, k: int | None = None) -> float:
    """
    Macro NDCG@k over all queries.
    ndcg_score returns 0 when a query has no positives, so no masking is required.
    """
    if y_true.size == 0:
        return np.nan
    return ndcg_score(y_true, y_score, k=k)
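
# Worked example (illustrative sketch): for a single query with relevances [1, 0, 1]
# and scores [0.5, 0.9, 0.8], the predicted ranking puts the irrelevant item first,
# so DCG = 1/log2(3) + 1/log2(4) ≈ 1.13 while IDCG = 1/log2(2) + 1/log2(3) ≈ 1.63,
# i.e. macro_ndcg(np.array([[1, 0, 1]]), np.array([[0.5, 0.9, 0.8]])) ≈ 0.69.
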
def macro_map(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """
    Macro MAP: mean of Average-Precision per query.
    Queries without positives contribute AP = 0.
    """
    if y_true.size == 0:
        return np.nan
    ap = []
    for t, s in zip(y_true, y_score):
        if t.sum() == 0:
            ap.append(0.0)
        else:
            ap.append(average_precision_score(t, s))
    return np.mean(ap)
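
# Worked example (illustrative sketch): for relevances [1, 0, 1] and scores
# [0.5, 0.9, 0.8], the two relevant items are retrieved at ranks 2 and 3, so
# AP = 0.5 * (1/2) + 0.5 * (2/3) ≈ 0.583; a second query without positives
# contributes 0, giving macro MAP ≈ 0.29 for the pair.
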
def ap_micro(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """
    Micro AP over the entire pool (unchanged).
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="No positive class found in y_true")
        return average_precision_score(y_true.ravel(), y_score.ravel())


# ---------------------------
# Recall@k
# ---------------------------
def recall_at_k_micro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
    """
    Micro Recall@k (a.k.a. instance-level recall)
    - Each *positive instance* counts once, regardless of which query it belongs to.
    - Denominator = total #positives across the whole pool.
    """
    total_pos = y_true.sum()
    if total_pos == 0:
        return np.nan
    topk = np.argsort(-y_score, axis=1)[:, :k]  # indices of top-k per query
    rows = np.arange(topk.shape[0])[:, None]
    hits = (y_true[rows, topk] > 0).sum()  # total #hits (instances)
    return hits / total_pos
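
# Worked example (illustrative sketch): with
#   _t = np.array([[1, 1, 0, 0], [0, 0, 0, 1]])
#   _s = np.array([[0.9, 0.1, 0.8, 0.2], [0.3, 0.9, 0.2, 0.1]])
# the two top-2 lists contain exactly one of the three positives overall, so
# recall_at_k_micro(_t, _s, k=2) == 1/3.
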
def recall_at_k_macro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
    """
    Macro Recall@k (query-level recall)
    - First compute recall per *query* (#hits / #positives in that query).
    - Then average across all queries that actually contain >=1 positive.
    """
    mask = y_true.sum(axis=1) > 0  # keep only valid queries
    if not mask.any():
        return np.nan
    Yt, Ys = y_true[mask], y_score[mask]
    topk = np.argsort(-Ys, axis=1)[:, :k]
    rows = np.arange(Yt.shape[0])[:, None]
    hits_per_q = (Yt[rows, topk] > 0).sum(axis=1)  # shape: (n_queries,)
    pos_per_q = Yt.sum(axis=1)
    return np.mean(hits_per_q / pos_per_q)  # average of query recalls
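
# Worked example (illustrative sketch): for the _t and _s arrays from the micro
# Recall@k example above, the per-query recalls at k=2 are 1/2 and 0, so
# recall_at_k_macro(_t, _s, k=2) == 0.25.
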


# ---------------------------
# Precision@k
# ---------------------------
def precision_at_k_micro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
    """
    Micro Precision@k (pool-level precision)
    - Numerator = total #hits across all queries.
    - Denominator = total #predictions considered (n_queries · k).
    """
    if y_true.size == 0:
        return np.nan
    topk = np.argsort(-y_score, axis=1)[:, :k]
    rows = np.arange(topk.shape[0])[:, None]
    hits = (y_true[rows, topk] > 0).sum()
    total_pred = y_true.shape[0] * k
    return hits / total_pred
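
# Worked example (illustrative sketch): for the same _t and _s, there is one hit
# among 2 queries * 2 = 4 considered predictions, so
# precision_at_k_micro(_t, _s, k=2) == 0.25.
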
def precision_at_k_macro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
    """
    Macro Precision@k (query-level precision)
    - Compute precision = (#hits / k) for each query, **including those with zero positives**,
      then average.
    """
    if y_true.size == 0:
        return np.nan
    topk = np.argsort(-y_score, axis=1)[:, :k]
    rows = np.arange(topk.shape[0])[:, None]
    rel = y_true[rows, topk] > 0  # shape: (n_queries, k)
    precision_per_q = rel.mean(axis=1)  # mean over k positions
    return precision_per_q.mean()
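
# Worked example (illustrative sketch): for the same _t and _s, the per-query
# precisions at k=2 are 1/2 and 0, so precision_at_k_macro(_t, _s, k=2) == 0.25.
# As long as k does not exceed the number of candidate columns, this coincides
# with the micro variant, since both divide the total hits by n_queries * k.
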


# helper methods
def bootstrap(
    metric_fn: Callable[[np.ndarray, np.ndarray], float],
    y_true: np.ndarray,
    y_score: np.ndarray,
    n: int = 1000,
    rng=None,
) -> dict[str, float]:
    rng = np.random.default_rng(rng)
    idx = np.arange(len(y_true))
    vals: list[float] = []
    while len(vals) < n:
        sample = rng.choice(idx, size=len(idx), replace=True)
        t = y_true[sample]
        s = y_score[sample]
        if t.sum() == 0:  # no positive at all -> resample
            continue
        vals.append(metric_fn(t, s))
    result = np.asarray(vals)
    # get 95% confidence interval
    lo, hi = np.percentile(result, [2.5, 97.5])
    return {"mean": result.mean(), "low": lo, "high": hi}
def evaluate_with_ranx(
    pred_rels: set[BinaryRelation],
    target_rels: set[BinaryRelation],
    metrics: list[str],
    include_queries_without_gold: bool = True,
) -> dict[str, float]:
    # lazy import so that ranx does not need to be listed in requirements.txt
    import ranx

    all_rels = set(pred_rels) | set(target_rels)
    all_heads = {rel.head for rel in all_rels}
    head2id = {head: f"q_{idx}" for idx, head in enumerate(sorted(all_heads))}
    tail_and_label2id = {(ann.tail, ann.label): f"d_{idx}" for idx, ann in enumerate(all_rels)}
    qrels_dict: dict[str, dict[str, int]] = defaultdict(dict)  # {query_id: {doc_id: 1}}
    run_dict: dict[str, dict[str, float]] = defaultdict(dict)  # {query_id: {doc_id: score}}
    for target_rel in target_rels:
        query_id = head2id[target_rel.head]
        doc_id = tail_and_label2id[(target_rel.tail, target_rel.label)]
        if target_rel.score != 1.0:
            raise ValueError(
                f"target score must be 1.0, but got {target_rel.score} for {target_rel}"
            )
        qrels_dict[query_id][doc_id] = 1
    for pred_rel in pred_rels:
        query_id = head2id[pred_rel.head]
        doc_id = tail_and_label2id[(pred_rel.tail, pred_rel.label)]
        run_dict[query_id][doc_id] = pred_rel.score
    if include_queries_without_gold:
        # add query ids without gold relations to qrels_dict
        for query_id in set(head2id.values()) - set(qrels_dict):
            qrels_dict[query_id] = {}
    # evaluate
    qrels = ranx.Qrels(qrels_dict)
    run = ranx.Run(run_dict)
    results = ranx.evaluate(qrels, run, metrics, make_comparable=True)
    return results
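
# Illustration (assumed example ids and scores): ranx consumes plain nested dicts
# of the shape built above, e.g.
#   qrels_dict = {"q_0": {"d_3": 1}, "q_1": {}}
#   run_dict   = {"q_0": {"d_3": 0.87, "d_5": 0.12}}
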
def deduplicate_relations(
    relations: Iterable[BinaryRelation], caption: str
) -> Set[BinaryRelation]:
    pred2scores = defaultdict(set)
    for ann in relations:
        pred2scores[ann].add(round(ann.score, 4))
    # warning for duplicates
    preds_with_duplicates = [ann for ann, scores in pred2scores.items() if len(scores) > 1]
    if len(preds_with_duplicates) > 0:
        logger.warning(
            f"there are {len(preds_with_duplicates)} {caption} with duplicates: "
            f"{preds_with_duplicates}. We will take the max score for each annotation."
        )
    # take the max score for each annotation
    result = {ann.copy(score=max(scores)) for ann, scores in pred2scores.items()}
    return result
def construct_y_true_and_score(
    preds: Iterable[BinaryRelation], targets: Iterable[BinaryRelation]
) -> Tuple[np.ndarray, np.ndarray]:
    # helper constructs
    all_anns = set(preds) | set(targets)
    head2relations = defaultdict(list)
    for ann in all_anns:
        head2relations[ann.head].append(ann)
    target2score = {rel: rel.score for rel in targets}
    pred2score = {rel: rel.score for rel in preds}
    max_len = max(len(relations) for relations in head2relations.values())
    target_rows, pred_rows = [], []
    for query in head2relations:
        relations = head2relations[query]
        # score for predictions that are missing for this query: NEG_INF keeps them
        # at the very bottom of the ranking (alternatives considered: 0.0 or a tiny random score)
        missing_pred_score = NEG_INF
        missing_target_score = 0
        query_scores = [
            (target2score.get(ann, missing_target_score), pred2score.get(ann, missing_pred_score))
            for ann in relations
        ]
        # sort by descending order of prediction score
        query_scores_sorted = np.array(sorted(query_scores, key=lambda x: x[1], reverse=True))
        # pad so every row has the same length (targets with 0, predictions with NEG_INF)
        pad_width = max_len - len(query_scores)
        query_target = np.pad(
            query_scores_sorted[:, 0], (0, pad_width), constant_values=missing_target_score
        )
        query_pred = np.pad(
            query_scores_sorted[:, 1], (0, pad_width), constant_values=missing_pred_score
        )
        target_rows.append(query_target)
        pred_rows.append(query_pred)
    y_true = np.vstack(target_rows)  # shape (n_queries, max_len)
    y_score = np.vstack(pred_rows)
    return y_true, y_score
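
# Illustration (assumed counts): if query A has three candidate relations and
# query B has one, both returned matrices have shape (2, 3); query B's row is
# padded with 0 in y_true and NEG_INF in y_score, so padded slots never outrank
# real candidates.
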
class SemanticallySameRankingMetric(DocumentMetric):
    def __init__(
        self,
        layer: str,
        label: Optional[str] = None,
        add_reversed: bool = False,
        require_positive_gold: bool = False,
        bootstrap_n: Optional[int] = None,
        k_values: Optional[List[int]] = None,
        return_coverage: bool = True,
        show_as_markdown: bool = False,
        use_ranx: bool = False,
        add_stats_to_result: bool = False,
    ) -> None:
        super().__init__()
        self.layer = layer
        self.label = label
        self.add_reversed = add_reversed
        self.require_positive_gold = require_positive_gold
        self.bootstrap_n = bootstrap_n
        self.k_values = k_values if k_values is not None else [1, 5, 10]
        self.return_coverage = return_coverage
        self.show_as_markdown = show_as_markdown
        self.use_ranx = use_ranx
        self.add_stats_to_result = add_stats_to_result
        self.metrics = {
            "macro_ndcg": macro_ndcg,
            "macro_mrr": true_mrr,
            "macro_map": macro_map,
            "micro_ap": ap_micro,
        }
        for name, func in [
            ("macro_ndcg", macro_ndcg),
            ("micro_recall", recall_at_k_micro),
            ("micro_precision", precision_at_k_micro),
            ("macro_recall", recall_at_k_macro),
            ("macro_precision", precision_at_k_macro),
        ]:
            for k in self.k_values:
                self.metrics[f"{name}@{k}"] = partial(func, k=k)  # type: ignore
        self.ranx_metrics = ["map", "mrr", "ndcg"]
        for name in ["recall", "precision", "ndcg"]:
            for k in self.k_values:
                self.ranx_metrics.append(f"{name}@{k}")

    def reset(self) -> None:
        """
        Reset the metric to its initial state.
        """
        self._preds: List[BinaryRelation] = []
        self._targets: List[BinaryRelation] = []

    def _update(self, document):
        layer = document[self.layer]
        ann: BinaryRelation
        for ann in layer:
            if self.label is None or ann.label == self.label:
                if ann.score > 0.0:
                    self._targets.append(ann.copy())
                    if self.add_reversed:
                        self._targets.append(ann.copy(head=ann.tail, tail=ann.head))
        for ann in layer.predictions:
            if self.label is None or ann.label == self.label:
                if ann.score > 0.0:
                    self._preds.append(ann.copy())
                    if self.add_reversed:
                        self._preds.append(ann.copy(head=ann.tail, tail=ann.head))

    def _compute(self):
        # take the max score for each annotation
        preds_deduplicated = deduplicate_relations(self._preds, "predictions")
        targets_deduplicated = deduplicate_relations(self._targets, "targets")
        stats = {
            "gold": len(targets_deduplicated),
            "preds": len(preds_deduplicated),
            "queries": len(
                set(ann.head for ann in targets_deduplicated)
                | set(ann.head for ann in preds_deduplicated)
            ),
        }
        if self.use_ranx:
            if self.bootstrap_n is not None:
                raise ValueError(
                    "Ranx does not support bootstrapping. Please set bootstrap_n=None."
                )
            scores = evaluate_with_ranx(
                preds_deduplicated,
                targets_deduplicated,
                metrics=self.ranx_metrics,
                include_queries_without_gold=not self.require_positive_gold,
            )
            if self.add_stats_to_result:
                scores.update(stats)
            # logger.info(f"results via ranx:\n{pd.Series(ranx_result).sort_index().round(3).to_markdown()}")
            df = pd.DataFrame.from_records([scores], index=["score"])
        else:
            y_true, y_score = construct_y_true_and_score(
                preds=preds_deduplicated, targets=targets_deduplicated
            )
            # original definition: share of queries with >=1 positive
            coverage = (y_true.sum(axis=1) > 0).mean()
            # keep only queries that actually have at least one gold positive
            if self.require_positive_gold:
                mask = y_true.sum(axis=1) > 0  # shape: (n_queries,)
                y_true = y_true[mask]
                y_score = y_score[mask]
            if self.bootstrap_n is not None:
                scores = {
                    name: bootstrap(fn, y_true, y_score, n=self.bootstrap_n)
                    for name, fn in self.metrics.items()
                }
                if self.add_stats_to_result:
                    scores["stats"] = stats
                df = pd.DataFrame(scores)
            else:
                scores = {name: fn(y_true, y_score) for name, fn in self.metrics.items()}
                if self.add_stats_to_result:
                    scores.update(stats)
                df = pd.DataFrame.from_records([scores], index=["score"])
            if self.return_coverage:
                scores["coverage"] = coverage
        if self.show_as_markdown:
            if not self.add_stats_to_result:
                logger.info(
                    f'\nstatistics ({self.layer}):\n{pd.Series(stats, name="value").to_markdown()}'
                )
            logger.info(f"\n{self.layer}:\n{df.round(4).T.to_markdown()}")
        return scores
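
# Usage sketch (assumptions: the standard pytorch_ie DocumentMetric interface is
# used to feed documents and collect the result; the layer name and settings below
# are made up for illustration):
#
#   metric = SemanticallySameRankingMetric(layer="binary_relations", k_values=[1, 5])
#   scores = metric(documents)  # assumed to update on each document and return the score dict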