update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529

import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
from pytorch_ie import Document, DocumentMetric

logger = logging.getLogger()


class ScoreDistribution(DocumentMetric):
    """Computes the distribution of prediction scores for annotations in a layer. The scores are
    separated into true positives (TP) and false positives (FP) based on the gold annotations.

    See the usage sketch at the end of this module for a minimal example.

    Args:
        layer: The name of the annotation layer to analyze.
        label_field: The field name of the label to use for separating the scores per label.
            Default is "label".
        per_label: If True, the scores are separated per label. Default is False.
        show_plot: If True, a plot of the score distribution is shown. Default is False.
        equal_sample_size_binning: If True, the scores are binned into bins with an equal number of
            samples. If False, the scores are binned into bins of equal width. The former is useful
            when the distribution of scores is skewed. Default is True.
        plotting_backend: The plotting backend to use. Currently, only "plotly" is supported.
            Default is "plotly".
        plotting_caption_mapping: A mapping to rename caption entries for plotting, i.e. the layer
            name, labels, or TP/FP. Default is None.
        plotting_colors: A dictionary mapping from gold scores to colors for plotting. Default is None.
        plotly_use_create_distplot: If True, use plotly.figure_factory.create_distplot for the score
            distribution plot; otherwise use plotly.express.histogram. Default is True.
        plotly_barmode: The barmode (e.g. "overlay" or "stack") for the score distribution plot.
            Default is None.
        plotly_marginal: The type of marginal subplot ("violin", "box", or "rug") when using
            plotly.express.histogram. Default is "violin".
        plotly_font: A dictionary with font settings (e.g. "size") for the plot layout. Default is None.
        plotly_font_size: Deprecated. Use plotly_font with a "size" entry instead.
        plotly_font_family: The font family to use for plotting. Default is None.
        plotly_background_color: The background color to use for the plot and paper. Default is None.
    """

    def __init__(
        self,
        layer: str,
        label_field: str = "label",
        per_label: bool = False,
        show_plot: bool = False,
        equal_sample_size_binning: bool = True,
        plotting_backend: str = "plotly",
        plotting_caption_mapping: Optional[Dict[str, str]] = None,
        plotting_colors: Optional[Dict[str, str]] = None,
        plotly_use_create_distplot: bool = True,
        plotly_barmode: Optional[str] = None,
        plotly_marginal: Optional[str] = "violin",
        plotly_font: Optional[Dict[str, Any]] = None,
        plotly_font_size: Optional[int] = None,
        plotly_font_family: Optional[str] = None,
        plotly_background_color: Optional[str] = None,
    ):
        super().__init__()
        self.layer = layer
        self.label_field = label_field
        self.per_label = per_label
        self.equal_sample_size_binning = equal_sample_size_binning
        self.plotting_backend = plotting_backend
        self.show_plot = show_plot
        self.plotting_caption_mapping = plotting_caption_mapping or {}
        self.plotting_colors = plotting_colors
        self.plotly_use_create_distplot = plotly_use_create_distplot
        self.plotly_barmode = plotly_barmode
        self.plotly_marginal = plotly_marginal
        self.plotly_font = plotly_font or {}
        if plotly_font_size is not None:
            logger.warning(
                "Parameter 'plotly_font_size' is deprecated. Use 'plotly_font' with 'size' key instead."
            )
            self.plotly_font["size"] = plotly_font_size
        self.plotly_font_family = plotly_font_family
        self.plotly_background_color = plotly_background_color
        self.scores: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))

    def reset(self):
        self.scores = defaultdict(lambda: defaultdict(list))

    def _update(self, document: Document):
        gold_annotations = set(document[self.layer])
        for ann in document[self.layer].predictions:
            if self.per_label:
                label = getattr(ann, self.label_field)
            else:
                label = "ALL"
            # A prediction counts as a TP if it is equal to some gold annotation. Note that for
            # pytorch_ie annotation types, the score field is typically excluded from equality.
            if ann in gold_annotations:
                self.scores[label]["TP"].append(ann.score)
            else:
                self.scores[label]["FP"].append(ann.score)

    def _combine_scores(
        self,
        scores_tp: List[float],
        scores_fp: List[float],
        col_name_pred: str = "prediction",
        col_name_gold: str = "gold",
    ) -> pd.DataFrame:
        scores_tp_df = pd.DataFrame(scores_tp, columns=[col_name_pred])
        scores_tp_df[col_name_gold] = 1.0
        scores_fp_df = pd.DataFrame(scores_fp, columns=[col_name_pred])
        scores_fp_df[col_name_gold] = 0.0
        scores_df = pd.concat([scores_tp_df, scores_fp_df])
        return scores_df
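
    # For orientation, a small illustration of what _combine_scores produces (assumed toy numbers,
    # not from the original code): _combine_scores([0.9, 0.8], [0.4]) yields a DataFrame with
    # columns "prediction" and "gold":
    #     prediction  gold
    #            0.9   1.0
    #            0.8   1.0
    #            0.4   0.0
    # i.e. TP scores are marked with gold=1.0 and FP scores with gold=0.0, which is the format
    # consumed by the calibration and correlation computations below.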

    def _get_calibration_data_and_metrics(
        self, scores: pd.DataFrame, q: int = 20
    ) -> Tuple[pd.DataFrame, pd.Series]:
        from sklearn.metrics import brier_score_loss

        if self.equal_sample_size_binning:
            # Create bins with equal number of samples.
            scores["bin"] = pd.qcut(scores["prediction"], q=q, labels=False)
        else:
            # Create bins with equal width.
            scores["bin"] = pd.cut(
                scores["prediction"],
                bins=q,
                include_lowest=True,
                right=True,
                labels=False,
            )
        calibration_data = (
            scores.groupby("bin")
            .apply(
                lambda x: pd.Series(
                    {
                        "avg_score": x["prediction"].mean(),
                        "fraction_positive": x["gold"].mean(),
                        "count": len(x),
                    }
                )
            )
            .reset_index()
        )
        total_count = scores.shape[0]
        calibration_data["bin_weight"] = calibration_data["count"] / total_count
        # Calculate the absolute differences and squared differences.
        calibration_data["abs_diff"] = abs(
            calibration_data["avg_score"] - calibration_data["fraction_positive"]
        )
        calibration_data["squared_diff"] = (
            calibration_data["avg_score"] - calibration_data["fraction_positive"]
        ) ** 2
        # Compute Expected Calibration Error (ECE): weighted average of absolute differences.
        ece = (calibration_data["abs_diff"] * calibration_data["bin_weight"]).sum()
        # Compute Maximum Calibration Error (MCE): maximum absolute difference.
        mce = calibration_data["abs_diff"].max()
        # Compute Mean Squared Error (MSE): weighted average of squared differences.
        mse = (calibration_data["squared_diff"] * calibration_data["bin_weight"]).sum()
        # Compute the Brier score on the raw predictions.
        brier = brier_score_loss(scores["gold"], scores["prediction"])
        values = {
            "ece": ece,
            "mce": mce,
            "mse": mse,
            "brier": brier,
        }
        return calibration_data, pd.Series(values)
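
    # A quick worked example of the calibration metrics above (assumed toy numbers, not from the
    # original code): with two bins of equal weight 0.5, where bin 1 has avg_score=0.90 and
    # fraction_positive=0.80 (abs_diff=0.10) and bin 2 has avg_score=0.30 and
    # fraction_positive=0.28 (abs_diff=0.02), we get
    #     ECE = 0.5 * 0.10 + 0.5 * 0.02           = 0.06
    #     MCE = max(0.10, 0.02)                   = 0.10
    #     MSE = 0.5 * 0.10**2 + 0.5 * 0.02**2     = 0.0052
    # The Brier score, in contrast, is computed on the raw (un-binned) predictions.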

    def calculate_calibration_metrics(self, scores_combined: pd.DataFrame) -> pd.DataFrame:
        calibration_data_dict = {}
        calibration_metrics_dict = {}
        for label, current_scores in scores_combined.groupby("label"):
            calibration_data, calibration_metrics = self._get_calibration_data_and_metrics(
                current_scores, q=20
            )
            calibration_data_dict[label] = calibration_data
            calibration_metrics_dict[label] = calibration_metrics
        all_calibration_data = pd.concat(
            calibration_data_dict, names=["label", "idx"]
        ).reset_index(level=0)
        all_calibration_metrics = pd.concat(calibration_metrics_dict, axis=1).T
        if self.show_plot:
            self.plot_calibration_data(calibration_data=all_calibration_data)
        return all_calibration_metrics

    def calculate_correlation(self, scores: pd.DataFrame) -> pd.Series:
        result_dict = {}
        for label, current_scores in scores.groupby("label"):
            result_dict[label] = current_scores.drop("label", axis=1).corr()["prediction"]["gold"]
        return pd.Series(result_dict, name="correlation")
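
    # Note: the correlation above is the Pearson correlation between the predicted score and the
    # binary gold indicator (a point-biserial correlation). As an assumed toy example, predictions
    # [0.9, 0.8, 0.3, 0.2] with gold [1, 1, 0, 0] correlate almost perfectly (close to 1.0), while
    # predictions that are unrelated to the TP/FP status give a correlation near 0.0.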

    @property
    def mapped_layer(self) -> str:
        return self.plotting_caption_mapping.get(self.layer, self.layer)

    def plot_score_distribution(self, scores: pd.DataFrame):
        if self.plotting_backend == "plotly":
            for label in scores["label"].unique():
                description = f"Distribution of Predicted Scores for {self.mapped_layer}"
                if self.per_label:
                    label_mapped = self.plotting_caption_mapping.get(label, label)
                    description += f" ({label_mapped})"
                # restrict to the scores of the current label
                current_scores = scores[scores["label"] == label]
                if self.plotly_use_create_distplot:
                    import plotly.figure_factory as ff

                    # group the prediction scores by gold score (1.0 for TP, 0.0 for FP)
                    scores_dict = (
                        current_scores.groupby("gold")["prediction"].apply(list).to_dict()
                    )
                    group_labels, hist_data = zip(*scores_dict.items())
                    group_labels_renamed = [
                        self.plotting_caption_mapping.get(label, label) for label in group_labels
                    ]
                    if self.plotting_colors is not None:
                        colors = [
                            self.plotting_colors[group_label] for group_label in group_labels
                        ]
                    else:
                        colors = None
                    fig = ff.create_distplot(
                        hist_data,
                        group_labels=group_labels_renamed,
                        show_hist=True,
                        colors=colors,
                        bin_size=0.025,
                    )
                else:
                    import plotly.express as px

                    fig = px.histogram(
                        current_scores,
                        x="prediction",
                        color="gold",
                        marginal=self.plotly_marginal,  # e.g. "violin", "box", or "rug"
                        hover_data=current_scores.columns,
                        color_discrete_map=self.plotting_colors,
                        nbins=50,
                    )
                fig.update_layout(
                    height=600,
                    width=800,
                    title_text=description,
                    title_x=0.5,
                    font=self.plotly_font,
                    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
                )
                if self.plotly_barmode is not None:
                    fig.update_layout(barmode=self.plotly_barmode)
                if self.plotly_font_family is not None:
                    fig.update_layout(font_family=self.plotly_font_family)
                if self.plotly_background_color is not None:
                    fig.update_layout(
                        plot_bgcolor=self.plotly_background_color,
                        paper_bgcolor=self.plotly_background_color,
                    )
                fig.show()
        else:
            raise NotImplementedError(f"Plotting backend {self.plotting_backend} not implemented")

    def plot_calibration_data(self, calibration_data: pd.DataFrame):
        import plotly.express as px
        import plotly.graph_objects as go

        color = "label" if self.per_label else None
        x_col = "avg_score"
        y_col = "fraction_positive"
        fig = px.scatter(
            calibration_data,
            x=x_col,
            y=y_col,
            color=color,
            trendline="ols",
            labels=self.plotting_caption_mapping,
        )
        if not self.per_label:
            fig["data"][1]["name"] = "prediction vs. gold"
        # show legend only for trendlines
        for idx, trace_data in enumerate(fig["data"]):
            if idx % 2 == 0:
                trace_data["showlegend"] = False
            else:
                trace_data["showlegend"] = True
        # add the optimal line
        minimum = calibration_data[x_col].min()
        maximum = calibration_data[x_col].max()
        fig.add_trace(
            go.Scatter(
                x=[minimum, maximum],
                y=[minimum, maximum],
                mode="lines",
                name="optimal",
                line=dict(color="black", dash="dash"),
            )
        )
        fig.update_layout(
            height=600,
            width=800,
            title_text=f"Mean Binned Scores for {self.mapped_layer}",
            title_x=0.5,
            font=self.plotly_font,
        )
        fig.update_layout(
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01,
                title="OLS trendline" + ("s" if self.per_label else ""),
            ),
        )
        if self.plotly_background_color is not None:
            fig.update_layout(
                plot_bgcolor=self.plotly_background_color,
                paper_bgcolor=self.plotly_background_color,
            )
        if self.plotly_font_family is not None:
            fig.update_layout(font_family=self.plotly_font_family)
        fig.show()

    def _compute(self) -> Dict[str, Dict[str, Any]]:
        scores_combined = pd.concat(
            {
                label: self._combine_scores(scores["TP"], scores["FP"])
                for label, scores in self.scores.items()
            },
            names=["label", "idx"],
        ).reset_index(level=0)
        result_df = scores_combined.groupby("label")["prediction"].agg(["mean", "std", "count"])
        if self.show_plot:
            self.plot_score_distribution(scores=scores_combined)
        calibration_metrics = self.calculate_calibration_metrics(scores_combined)
        calibration_metrics["correlation"] = self.calculate_correlation(scores_combined)
        result_df = pd.concat(
            {"prediction": result_df, "prediction vs. gold": calibration_metrics}, axis=1
        )
        if not self.per_label:
            result = result_df.xs("ALL")
        else:
            result = result_df.T.stack().unstack()
        # restructure into a nested dict: {section: {metric: value}} without per_label, or
        # {section: {metric: {label: value}}} with per_label
        result_dict = {
            main_key: result.xs(main_key).T.to_dict()
            for main_key in result.index.get_level_values(0).unique()
        }
        return result_dict
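

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, hedged example of how this metric could be wired up with pytorch_ie. The document
# class, the layer name ("entities"), and the import paths below are assumptions for illustration
# and may need to be adapted to the concrete pytorch_ie version and document types in use:
#
#     from dataclasses import dataclass
#
#     from pytorch_ie import AnnotationLayer, annotation_field
#     from pytorch_ie.annotations import LabeledSpan
#     from pytorch_ie.documents import TextBasedDocument
#
#     @dataclass
#     class ExampleDocument(TextBasedDocument):
#         entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
#
#     doc = ExampleDocument(text="Berlin is in Germany.")
#     doc.entities.append(LabeledSpan(start=0, end=6, label="LOC"))  # gold annotation
#     # a matching prediction (TP) with a high score, and a spurious one (FP) with a low score
#     doc.entities.predictions.append(LabeledSpan(start=0, end=6, label="LOC", score=0.9))
#     doc.entities.predictions.append(LabeledSpan(start=13, end=20, label="LOC", score=0.4))
#
#     metric = ScoreDistribution(layer="entities", per_label=True)
#     metric(doc)                # updates the TP/FP score lists via _update()
#     result = metric.compute()  # nested dict with mean/std/count and the calibration metrics
#
# Note that the calibration metrics bin the scores into q=20 quantiles, so meaningful results
# (and pd.qcut itself) require scores from many predictions, i.e. the metric is intended to be
# applied to a whole dataset of documents rather than a single one.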