import logging import math from typing import Any, Callable, Dict, List, Optional, Union, overload import numpy as np from pandas import MultiIndex from pie_modules.utils import flatten_dict from pytorch_ie import Document, DocumentMetric from pytorch_ie.core.metric import T from pytorch_ie.utils.hydra import resolve_target from src.hydra_callbacks.save_job_return_value import to_py_obj logger = logging.getLogger(__name__) def get_num_total(targets: List[int], preds: List[float]): return len(targets) def get_num_positives(targets: List[int], preds: List[float], positive_idx: int = 1): return len([v for v in targets if v == positive_idx]) @overload def discretize(values: List[float], threshold: float) -> List[float]: ... @overload def discretize(values: List[float], threshold: List[float]) -> Dict[Any, List[float]]: ... def discretize( values: List[float], threshold: Union[float, List[float], dict] ) -> Union[List[float], Dict[Any, List[float]]]: if isinstance(threshold, float): result = (np.array(values) >= threshold).astype(int).tolist() return result if isinstance(threshold, list): return {t: discretize(values=values, threshold=t) for t in threshold} # type: ignore if isinstance(threshold, dict): thresholds = ( np.arange(threshold["start"], threshold["end"], threshold["step"]).round(4).tolist() ) return discretize(values, threshold=thresholds) raise TypeError(f"threshold has unknown type: {threshold}") def get_metric_func(name: str) -> Callable: if name.endswith("_curve"): from sklearn.metrics import auc base_func = resolve_target(name) def wrapper(targets: List[int], preds: List[float], **kwargs): x, y, thresholds = base_func(targets, preds, **kwargs) return auc(y, x) return wrapper else: return resolve_target(name) def bootstrap( metric_fn: Callable[[List[int], Union[List[int], List[float]]], float], targets: List[int], predictions: Union[List[int], List[float]], n: int = 1_000, random_state: int | None = None, alpha: float = 0.95, ) -> Dict[str, float]: """ Returns mean and a two–sided (1–alpha) bootstrap CI for any pair-wise classification or ranking metric. Parameters ---------- metric_fn Metric function taking (targets, prediction) lists. targets Ground-truth 0/1 labels. prediction Scores or hard predictions (same length as `targets`). n Number of bootstrap replicates (after skipping degenerate ones). random_state Seed for reproducibility. alpha Confidence level (default 0.95 → 95 % CI). Notes ----- * A replicate that contains only one class is discarded because many sklearn metrics are undefined in that case. * If all replicates are discarded an exception is raised. """ y = np.asarray(targets) yhat = np.asarray(predictions) if y.shape[0] != yhat.shape[0]: raise ValueError("`targets` and `prediction` must have the same length") rng = np.random.default_rng(random_state) idx = np.arange(y.shape[0]) vals_list: list[float] = [] while len(vals_list) < n: sample_idx = rng.choice(idx, size=idx.shape[0], replace=True) y_samp, yhat_samp = y[sample_idx], yhat[sample_idx] # skip all-positive or all-negative bootstrap samples if y_samp.min() == y_samp.max(): continue vals_list.append(metric_fn(y_samp.tolist(), yhat_samp.tolist())) if not vals_list: raise RuntimeError("No valid bootstrap replicate contained both classes.") vals = np.asarray(vals_list, dtype=float) lower = np.percentile(vals, (1 - alpha) / 2 * 100) upper = np.percentile(vals, (1 + alpha) / 2 * 100) return {"mean": float(vals.mean()), "low": float(lower), "high": float(upper)} class BinaryClassificationMetricsSKLearn(DocumentMetric): def __init__( self, metrics: Dict[str, str], layer: str, label: Optional[str] = None, thresholds: Optional[Dict[str, float]] = None, default_target_idx: int = 0, default_prediction_score: float = 0.0, show_as_markdown: bool = False, markdown_precision: int = 4, bootstrap: Optional[list[str]] = None, bootstrap_n: int = 1_000, bootstrap_random_state: int | None = None, bootstrap_alpha: float = 0.95, create_plots: bool = True, plots: Optional[Dict[str, str]] = None, ): self.metrics = {name: get_metric_func(metric) for name, metric in metrics.items()} self.thresholds = thresholds or {} thresholds_not_in_metrics = { name: t for name, t in self.thresholds.items() if name not in self.metrics } if len(thresholds_not_in_metrics) > 0: logger.warning( f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}" ) self.annotation_layer_name = layer self.annotation_label = label self.default_target_idx = default_target_idx self.default_prediction_score = default_prediction_score self.show_as_markdown = show_as_markdown self.markdown_precision = markdown_precision if create_plots: self.plots = { name: resolve_target(plot_func) for name, plot_func in (plots or {}).items() } else: self.plots = {} self.bootstrap = set(bootstrap or []) self.bootstrap_kwargs = { "n": bootstrap_n, "random_state": bootstrap_random_state, "alpha": bootstrap_alpha, } super().__init__() def reset(self) -> None: self._preds: List[float] = [] self._targets: List[int] = [] def _update(self, document: Document) -> None: annotation_layer = document[self.annotation_layer_name] target2idx = { ann: int(ann.score) for ann in annotation_layer if self.annotation_label is None or ann.label == self.annotation_label } prediction2score = { ann: ann.score for ann in annotation_layer.predictions if self.annotation_label is None or ann.label == self.annotation_label } all_args = set(target2idx) | set(prediction2score) all_targets: List[int] = [] all_predictions: List[float] = [] for args in all_args: target_idx = target2idx.get(args, self.default_target_idx) prediction_score = prediction2score.get(args, self.default_prediction_score) all_targets.append(target_idx) all_predictions.append(prediction_score) self._preds.extend(all_predictions) self._targets.extend(all_targets) def create_plots(self): from matplotlib import pyplot as plt # Get the number of metrics num_plots = len(self.plots) # Calculate rows and columns for subplots (aim for a square-like layout) ncols = math.ceil(math.sqrt(num_plots)) nrows = math.ceil(num_plots / ncols) # Create the subplots fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10)) # Flatten the ax_list if necessary (in case of multiple rows/columns) if num_plots > 1: ax_list = ax_list.flatten().tolist() # Ensure it's a list, and flatten it if necessary else: ax_list = [ax_list] # Create each plot for ax, (name, plot_func) in zip(ax_list, self.plots.items()): # Set the title for each subplot ax.set_title(name) plot_func(y_true=self._targets, y_pred=self._preds, ax=ax) # Adjust layout to avoid overlapping plots plt.tight_layout() plt.show() def _compute(self) -> T: if len(self.plots) > 0: self.create_plots() result = {} for name, metric in self.metrics.items(): if name in self.thresholds: preds_dict = discretize(values=self._preds, threshold=self.thresholds[name]) if isinstance(preds_dict, dict): metric_results = { t: metric(self._targets, t_preds) for t, t_preds in preds_dict.items() } # just get the max max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1]) result[f"{name}_threshold"] = max_t preds = discretize(values=self._preds, threshold=max_t) else: preds = preds_dict else: preds = self._preds if name in self.bootstrap: # bootstrap the metric result[name] = bootstrap( metric_fn=metric, targets=self._targets, predictions=preds, **self.bootstrap_kwargs, # type: ignore ) else: result[name] = metric(self._targets, preds) result = to_py_obj(result) if self.show_as_markdown: import pandas as pd result_flat = flatten_dict(result) series = pd.Series(result_flat) if isinstance(series.index, MultiIndex): if len(series.index.levels) > 1: # in fact, this is not a series anymore series = series.unstack(-1) else: series.index = series.index.get_level_values(0) logger.info( f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}" ) return result