# Author: ArneBinder
# Update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
# Commit: d868d2e (verified)
import logging
import math
from typing import Any, Callable, Dict, List, Optional, Union, overload
import numpy as np
from pandas import MultiIndex
from pie_modules.utils import flatten_dict
from pytorch_ie import Document, DocumentMetric
from pytorch_ie.core.metric import T
from pytorch_ie.utils.hydra import resolve_target
from src.hydra_callbacks.save_job_return_value import to_py_obj
logger = logging.getLogger(__name__)
def get_num_total(targets: List[int], preds: List[float]):
    """Return the number of examples, i.e. the length of ``targets``.

    ``preds`` is accepted only so that the signature matches the other
    ``(targets, preds)`` metric functions; it is not used.
    """
    num_examples = len(targets)
    return num_examples
def get_num_positives(targets: List[int], preds: List[float], positive_idx: int = 1):
    """Count how many entries of ``targets`` are equal to ``positive_idx``.

    ``preds`` is accepted only for signature compatibility with the other
    ``(targets, preds)`` metric functions; it is not used.
    """
    return sum(1 for label in targets if label == positive_idx)
@overload
def discretize(values: List[float], threshold: float) -> List[int]: ...


@overload
def discretize(values: List[float], threshold: List[float]) -> Dict[Any, List[int]]: ...


@overload
def discretize(values: List[float], threshold: Dict[str, float]) -> Dict[Any, List[int]]: ...


def discretize(
    values: List[float], threshold: Union[float, List[float], dict]
) -> Union[List[int], Dict[Any, List[int]]]:
    """Binarize ``values`` with one or several thresholds.

    Args:
        values: the scores to binarize.
        threshold: one of
            * a single number (int or float): each value becomes ``1`` if it is
              ``>= threshold`` and ``0`` otherwise;
            * a (non-string) sequence of numbers: returns a dict mapping each
              threshold to the binarized values for that threshold;
            * a mapping with the keys ``"start"``, ``"end"`` and ``"step"``: the
              thresholds are ``np.arange(start, end, step)`` (``end`` exclusive),
              rounded to 4 decimals, and the result is a dict as for a sequence.

    Returns:
        A list of 0/1 ints for a single threshold, otherwise a dict from
        threshold to such a list.

    Raises:
        TypeError: if ``threshold`` has none of the supported types. ``bool`` and
            strings are rejected explicitly.
    """
    import numbers
    from collections.abc import Mapping, Sequence

    # Integer thresholds (e.g. ``1`` coming from a config file) are valid too.
    # bool is a numbers.Real subclass, but a boolean threshold is almost
    # certainly a configuration mistake, so it is rejected.
    if isinstance(threshold, numbers.Real) and not isinstance(threshold, bool):
        return (np.array(values) >= threshold).astype(int).tolist()
    if isinstance(threshold, Mapping):
        # round to avoid float artifacts such as 0.30000000000000004 in the keys
        thresholds = (
            np.arange(threshold["start"], threshold["end"], threshold["step"]).round(4).tolist()
        )
        return discretize(values, threshold=thresholds)
    if isinstance(threshold, Sequence) and not isinstance(threshold, (str, bytes)):
        return {t: discretize(values=values, threshold=t) for t in threshold}  # type: ignore
    raise TypeError(f"threshold has unknown type: {threshold}")
def get_metric_func(name: str) -> Callable:
    """Resolve a metric function from its (fully qualified) target name.

    Names ending in ``_curve`` (e.g. ``sklearn.metrics.precision_recall_curve`` or
    ``sklearn.metrics.roc_curve``) are expected to return a
    ``(first, second, thresholds)`` triple, as the sklearn curve functions do.
    They are wrapped so that the returned callable computes the area under that
    curve with ``sklearn.metrics.auc``. All other names are returned as resolved.

    Args:
        name: target path of the metric function.

    Returns:
        A callable taking ``(targets, preds, **kwargs)``.
    """
    if not name.endswith("_curve"):
        return resolve_target(name)

    from sklearn.metrics import auc

    base_func = resolve_target(name)
    # precision_recall_curve returns (precision, recall, thresholds), and its area is
    # taken over recall, i.e. the SECOND output is the x-axis. Other sklearn curves
    # such as roc_curve return (fpr, tpr, thresholds), with the x-axis FIRST;
    # integrating them the other way round would give e.g. 0 for a perfect ROC.
    x_is_second = name.endswith("precision_recall_curve")

    def wrapper(targets: List[int], preds: List[float], **kwargs):
        first, second, _thresholds = base_func(targets, preds, **kwargs)
        if x_is_second:
            return auc(second, first)
        return auc(first, second)

    return wrapper
def bootstrap(
metric_fn: Callable[[List[int], Union[List[int], List[float]]], float],
targets: List[int],
predictions: Union[List[int], List[float]],
n: int = 1_000,
random_state: int | None = None,
alpha: float = 0.95,
) -> Dict[str, float]:
"""
Returns mean and a two–sided (1–alpha) bootstrap CI for any
pair-wise classification or ranking metric.
Parameters
----------
metric_fn Metric function taking (targets, prediction) lists.
targets Ground-truth 0/1 labels.
prediction Scores or hard predictions (same length as `targets`).
n Number of bootstrap replicates (after skipping degenerate ones).
random_state Seed for reproducibility.
alpha Confidence level (default 0.95 → 95 % CI).
Notes
-----
* A replicate that contains only one class is discarded
because many sklearn metrics are undefined in that case.
* If all replicates are discarded an exception is raised.
"""
y = np.asarray(targets)
yhat = np.asarray(predictions)
if y.shape[0] != yhat.shape[0]:
raise ValueError("`targets` and `prediction` must have the same length")
rng = np.random.default_rng(random_state)
idx = np.arange(y.shape[0])
vals_list: list[float] = []
while len(vals_list) < n:
sample_idx = rng.choice(idx, size=idx.shape[0], replace=True)
y_samp, yhat_samp = y[sample_idx], yhat[sample_idx]
# skip all-positive or all-negative bootstrap samples
if y_samp.min() == y_samp.max():
continue
vals_list.append(metric_fn(y_samp.tolist(), yhat_samp.tolist()))
if not vals_list:
raise RuntimeError("No valid bootstrap replicate contained both classes.")
vals = np.asarray(vals_list, dtype=float)
lower = np.percentile(vals, (1 - alpha) / 2 * 100)
upper = np.percentile(vals, (1 + alpha) / 2 * 100)
return {"mean": float(vals.mean()), "low": float(lower), "high": float(upper)}
class BinaryClassificationMetricsSKLearn(DocumentMetric):
    """Document metric that collects (target, prediction score) pairs for the
    annotations of one layer and evaluates them with (sklearn-style) metric functions.

    Parameters:
        metrics: mapping from a result name to the target path of a metric function,
            resolved via ``get_metric_func``. Names ending in ``_curve`` become an
            area-under-curve metric.
        layer: name of the annotation layer to evaluate.
        label: if given, only annotations with this label are considered.
        thresholds: optional mapping from metric name to a threshold specification that
            is passed to ``discretize`` to binarize the prediction scores before calling
            the metric. If the specification yields several thresholds, the one with the
            highest metric value is used and reported as ``<name>_threshold``.
        default_target_idx: target used for predicted annotations without a gold counterpart.
        default_prediction_score: score used for gold annotations without a predicted
            counterpart.
        show_as_markdown: if True, the results are logged as a markdown table.
        markdown_precision: number of decimals used in the markdown table.
        bootstrap: names of metrics for which the ``bootstrap`` result
            (mean / low / high) is reported instead of a single value.
        bootstrap_n, bootstrap_random_state, bootstrap_alpha: forwarded to ``bootstrap``
            as ``n``, ``random_state`` and ``alpha``.
        create_plots: whether the plots given in ``plots`` are created.
        plots: mapping from a plot name to the target path of a plotting function that
            is called as ``plot_func(y_true=..., y_pred=..., ax=...)``.
    """

    def __init__(
        self,
        metrics: Dict[str, str],
        layer: str,
        label: Optional[str] = None,
        thresholds: Optional[Dict[str, Union[float, List[float], Dict[str, float]]]] = None,
        default_target_idx: int = 0,
        default_prediction_score: float = 0.0,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        bootstrap: Optional[list[str]] = None,
        bootstrap_n: int = 1_000,
        bootstrap_random_state: int | None = None,
        bootstrap_alpha: float = 0.95,
        create_plots: bool = True,
        plots: Optional[Dict[str, str]] = None,
    ):
        self.metrics = {name: get_metric_func(metric) for name, metric in metrics.items()}
        self.thresholds = thresholds or {}
        thresholds_not_in_metrics = {
            name: t for name, t in self.thresholds.items() if name not in self.metrics
        }
        if len(thresholds_not_in_metrics) > 0:
            logger.warning(
                f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}"
            )
        self.annotation_layer_name = layer
        self.annotation_label = label
        self.default_target_idx = default_target_idx
        self.default_prediction_score = default_prediction_score
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        if create_plots:
            self.plots = {
                name: resolve_target(plot_func) for name, plot_func in (plots or {}).items()
            }
        else:
            self.plots = {}
        # metric names whose value is computed with the module-level `bootstrap` function
        self.bootstrap = set(bootstrap or [])
        self.bootstrap_kwargs = {
            "n": bootstrap_n,
            "random_state": bootstrap_random_state,
            "alpha": bootstrap_alpha,
        }
        super().__init__()

    def reset(self) -> None:
        """Clear the collected prediction scores and targets."""
        self._preds: List[float] = []
        self._targets: List[int] = []

    def _update(self, document: Document) -> None:
        """Collect (target, prediction score) pairs for the relevant annotations of ``document``."""
        annotation_layer = document[self.annotation_layer_name]
        # gold annotations -> target index
        # NOTE(review): assumes the gold class index is stored in the annotation's `score`
        target2idx = {
            ann: int(ann.score)
            for ann in annotation_layer
            if self.annotation_label is None or ann.label == self.annotation_label
        }
        # predicted annotations -> prediction score
        prediction2score = {
            ann: ann.score
            for ann in annotation_layer.predictions
            if self.annotation_label is None or ann.label == self.annotation_label
        }
        # Union of gold and predicted annotations in a deterministic order: gold first,
        # then predictions without a gold counterpart. Iterating a set union instead gives
        # a hash-dependent order (e.g. with randomized string hashing), which can make
        # bootstrap results differ between runs even with a fixed random state.
        all_anns = list(target2idx) + [ann for ann in prediction2score if ann not in target2idx]
        self._targets.extend(target2idx.get(ann, self.default_target_idx) for ann in all_anns)
        self._preds.extend(
            prediction2score.get(ann, self.default_prediction_score) for ann in all_anns
        )

    def create_plots(self):
        """Draw all configured plots in one figure (square-like grid) and show it."""
        if not self.plots:
            return

        from matplotlib import pyplot as plt

        num_plots = len(self.plots)
        # aim for a square-like layout
        ncols = math.ceil(math.sqrt(num_plots))
        nrows = math.ceil(num_plots / ncols)
        # squeeze=False always returns a 2D array of axes, so a single plot needs no special case
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10), squeeze=False)
        ax_list = axes.flatten().tolist()

        for ax, (name, plot_func) in zip(ax_list, self.plots.items()):
            ax.set_title(name)
            plot_func(y_true=self._targets, y_pred=self._preds, ax=ax)

        # hide the grid cells that did not receive a plot
        for unused_ax in ax_list[num_plots:]:
            unused_ax.set_visible(False)

        # adjust layout to avoid overlapping plots
        plt.tight_layout()
        plt.show()

    def _compute(self) -> T:
        """Evaluate all configured metrics on the collected data (and create plots, if any)."""
        if len(self.plots) > 0:
            self.create_plots()

        result = {}
        for name, metric in self.metrics.items():
            if name in self.thresholds:
                preds_dict = discretize(values=self._preds, threshold=self.thresholds[name])
                if isinstance(preds_dict, dict):
                    metric_results = {
                        t: metric(self._targets, t_preds) for t, t_preds in preds_dict.items()
                    }
                    # keep the threshold with the highest metric value (first one on ties)
                    best_t = max(metric_results, key=metric_results.__getitem__)
                    result[f"{name}_threshold"] = best_t
                    # reuse the predictions already discretized with that threshold
                    preds = preds_dict[best_t]
                else:
                    preds = preds_dict
            else:
                preds = self._preds

            if name in self.bootstrap:
                # module-level `bootstrap` function (the attribute is the set of names)
                result[name] = bootstrap(
                    metric_fn=metric,
                    targets=self._targets,
                    predictions=preds,
                    **self.bootstrap_kwargs,  # type: ignore
                )
            else:
                result[name] = metric(self._targets, preds)

        result = to_py_obj(result)
        if self.show_as_markdown:
            import pandas as pd

            result_flat = flatten_dict(result)
            series = pd.Series(result_flat)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # in fact, this is not a series anymore
                    series = series.unstack(-1)
                else:
                    series.index = series.index.get_level_values(0)
            logger.info(
                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
            )
        return result