import logging
import math
from typing import Any, Callable, Dict, List, Optional, Union, overload

import numpy as np
from pandas import MultiIndex
from pie_modules.utils import flatten_dict
from pytorch_ie import Document, DocumentMetric
from pytorch_ie.core.metric import T
from pytorch_ie.utils.hydra import resolve_target

from src.hydra_callbacks.save_job_return_value import to_py_obj

logger = logging.getLogger(__name__)


def get_num_total(targets: List[int], preds: List[float]) -> int:
    return len(targets)


def get_num_positives(targets: List[int], preds: List[float], positive_idx: int = 1) -> int:
    return len([v for v in targets if v == positive_idx])


@overload
def discretize(values: List[float], threshold: float) -> List[int]: ...


@overload
def discretize(values: List[float], threshold: List[float]) -> Dict[Any, List[int]]: ...


def discretize(
    values: List[float], threshold: Union[float, List[float], dict]
) -> Union[List[int], Dict[Any, List[int]]]:
    if isinstance(threshold, float):
        result = (np.array(values) >= threshold).astype(int).tolist()
        return result
    if isinstance(threshold, list):
        return {t: discretize(values=values, threshold=t) for t in threshold}  # type: ignore
    if isinstance(threshold, dict):
        thresholds = (
            np.arange(threshold["start"], threshold["end"], threshold["step"]).round(4).tolist()
        )
        return discretize(values, threshold=thresholds)
    raise TypeError(f"threshold has unknown type: {threshold}")
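

# Illustrative only: how `discretize` behaves for the three supported threshold types
# (the values are arbitrary, not taken from any experiment):
#   discretize([0.2, 0.7, 0.9], threshold=0.5)
#       -> [0, 1, 1]
#   discretize([0.2, 0.7, 0.9], threshold=[0.5, 0.8])
#       -> {0.5: [0, 1, 1], 0.8: [0, 0, 1]}
#   discretize([0.2, 0.7, 0.9], threshold={"start": 0.1, "end": 0.9, "step": 0.4})
#       -> {0.1: [1, 1, 1], 0.5: [0, 1, 1]}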


def get_metric_func(name: str) -> Callable:
    # For curve metrics (e.g. sklearn.metrics.precision_recall_curve) return a wrapper that
    # computes the area under the returned curve. The curve function is expected to return a
    # (x_values, y_values, thresholds) triple; auc(y, x) then integrates the curve, e.g.
    # auc(recall, precision) for the precision-recall curve.
    if name.endswith("_curve"):
        from sklearn.metrics import auc

        base_func = resolve_target(name)

        def wrapper(targets: List[int], preds: List[float], **kwargs):
            x, y, thresholds = base_func(targets, preds, **kwargs)
            return auc(y, x)

        return wrapper
    else:
        return resolve_target(name)
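

# Illustrative only: `get_metric_func` resolves dotted paths to callables; names ending in
# "_curve" are wrapped to return area-under-curve scores. The paths below are assumed to be
# resolvable by `resolve_target`, as for sklearn metric functions:
#   auc_fn = get_metric_func("sklearn.metrics.roc_auc_score")               # used as-is
#   pr_auc_fn = get_metric_func("sklearn.metrics.precision_recall_curve")   # wrapped -> AUC-PR
#   pr_auc_fn([0, 1, 1, 0], [0.1, 0.8, 0.6, 0.4])                           # returns a float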


def bootstrap(
    metric_fn: Callable[[List[int], Union[List[int], List[float]]], float],
    targets: List[int],
    predictions: Union[List[int], List[float]],
    n: int = 1_000,
    random_state: int | None = None,
    alpha: float = 0.95,
) -> Dict[str, float]:
"""
Returns mean and a two–sided (1–alpha) bootstrap CI for any
pair-wise classification or ranking metric.
Parameters
----------
metric_fn Metric function taking (targets, prediction) lists.
targets Ground-truth 0/1 labels.
prediction Scores or hard predictions (same length as `targets`).
n Number of bootstrap replicates (after skipping degenerate ones).
random_state Seed for reproducibility.
alpha Confidence level (default 0.95 → 95 % CI).
Notes
-----
* A replicate that contains only one class is discarded
because many sklearn metrics are undefined in that case.
* If all replicates are discarded an exception is raised.
"""
    y = np.asarray(targets)
    yhat = np.asarray(predictions)
    if y.shape[0] != yhat.shape[0]:
        raise ValueError("`targets` and `predictions` must have the same length")
    if y.min() == y.max():
        # with only one class in the full data, no bootstrap replicate can ever be valid,
        # so the resampling loop below would never terminate
        raise RuntimeError("`targets` must contain both classes to bootstrap a binary metric.")
    rng = np.random.default_rng(random_state)
    idx = np.arange(y.shape[0])
    vals_list: list[float] = []
    while len(vals_list) < n:
        sample_idx = rng.choice(idx, size=idx.shape[0], replace=True)
        y_samp, yhat_samp = y[sample_idx], yhat[sample_idx]
        # skip all-positive or all-negative bootstrap samples
        if y_samp.min() == y_samp.max():
            continue
        vals_list.append(metric_fn(y_samp.tolist(), yhat_samp.tolist()))
    vals = np.asarray(vals_list, dtype=float)
    lower = np.percentile(vals, (1 - alpha) / 2 * 100)
    upper = np.percentile(vals, (1 + alpha) / 2 * 100)
    return {"mean": float(vals.mean()), "low": float(lower), "high": float(upper)}


class BinaryClassificationMetricsSKLearn(DocumentMetric):
    """Compute binary classification metrics from annotation scores.

    Gold annotations in the given layer provide the targets (their score cast to int),
    predicted annotations provide the prediction scores. Metric and plot functions are
    given as dotted paths resolvable by `resolve_target`. Scores can optionally be
    discretized at (swept) thresholds, bootstrapped into confidence intervals, plotted,
    and logged as a markdown table.
    """

    def __init__(
        self,
        metrics: Dict[str, str],
        layer: str,
        label: Optional[str] = None,
        thresholds: Optional[Dict[str, Union[float, List[float], dict]]] = None,
        default_target_idx: int = 0,
        default_prediction_score: float = 0.0,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        bootstrap: Optional[list[str]] = None,
        bootstrap_n: int = 1_000,
        bootstrap_random_state: int | None = None,
        bootstrap_alpha: float = 0.95,
        create_plots: bool = True,
        plots: Optional[Dict[str, str]] = None,
    ):
        self.metrics = {name: get_metric_func(metric) for name, metric in metrics.items()}
        self.thresholds = thresholds or {}
        thresholds_not_in_metrics = {
            name: t for name, t in self.thresholds.items() if name not in self.metrics
        }
        if len(thresholds_not_in_metrics) > 0:
            logger.warning(
                f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}"
            )
        self.annotation_layer_name = layer
        self.annotation_label = label
        self.default_target_idx = default_target_idx
        self.default_prediction_score = default_prediction_score
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        if create_plots:
            self.plots = {
                name: resolve_target(plot_func) for name, plot_func in (plots or {}).items()
            }
        else:
            self.plots = {}
        self.bootstrap = set(bootstrap or [])
        self.bootstrap_kwargs = {
            "n": bootstrap_n,
            "random_state": bootstrap_random_state,
            "alpha": bootstrap_alpha,
        }
        super().__init__()
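
    # Configuration sketch (hypothetical values; metric entries are dotted paths that
    # `resolve_target` can import, and the layer name must exist in the documents):
    #   metric = BinaryClassificationMetricsSKLearn(
    #       metrics={
    #           "auc": "sklearn.metrics.roc_auc_score",
    #           "f1": "sklearn.metrics.f1_score",
    #       },
    #       thresholds={"f1": {"start": 0.1, "end": 1.0, "step": 0.1}},
    #       bootstrap=["auc"],
    #       layer="binary_relations",
    #       create_plots=False,
    #   )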

    def reset(self) -> None:
        self._preds: List[float] = []
        self._targets: List[int] = []

    def _update(self, document: Document) -> None:
        annotation_layer = document[self.annotation_layer_name]
        # gold annotations: the annotation score, cast to int, provides the target index
        target2idx = {
            ann: int(ann.score)
            for ann in annotation_layer
            if self.annotation_label is None or ann.label == self.annotation_label
        }
        # predicted annotations: the annotation score is used as the prediction
        prediction2score = {
            ann: ann.score
            for ann in annotation_layer.predictions
            if self.annotation_label is None or ann.label == self.annotation_label
        }
        all_args = set(target2idx) | set(prediction2score)
        all_targets: List[int] = []
        all_predictions: List[float] = []
        for args in all_args:
            target_idx = target2idx.get(args, self.default_target_idx)
            prediction_score = prediction2score.get(args, self.default_prediction_score)
            all_targets.append(target_idx)
            all_predictions.append(prediction_score)
        self._preds.extend(all_predictions)
        self._targets.extend(all_targets)
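
    # Example of the alignment above (schematic): if the gold layer contains {A: 1.0} and the
    # predictions contain {A: 0.9, B: 0.3}, the union {A, B} yields targets [1, 0] (B falls
    # back to `default_target_idx`) and predictions [0.9, 0.3].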

    def create_plots(self):
        from matplotlib import pyplot as plt

        # Get the number of plots
        num_plots = len(self.plots)
        # Calculate rows and columns for subplots (aim for a square-like layout)
        ncols = math.ceil(math.sqrt(num_plots))
        nrows = math.ceil(num_plots / ncols)
        # Create the subplots
        fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10))
        # Flatten the ax_list if necessary (in case of multiple rows/columns)
        if num_plots > 1:
            ax_list = ax_list.flatten().tolist()  # Ensure it's a list, and flatten it if necessary
        else:
            ax_list = [ax_list]
        # Create each plot
        for ax, (name, plot_func) in zip(ax_list, self.plots.items()):
            # Set the title for each subplot
            ax.set_title(name)
            plot_func(y_true=self._targets, y_pred=self._preds, ax=ax)
        # Adjust layout to avoid overlapping plots
        plt.tight_layout()
        plt.show()
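
    # Plot config sketch (assumes callables that accept y_true, y_pred and ax keyword
    # arguments, e.g. sklearn's display helpers; the paths are illustrative):
    #   plots={
    #       "roc": "sklearn.metrics.RocCurveDisplay.from_predictions",
    #       "precision_recall": "sklearn.metrics.PrecisionRecallDisplay.from_predictions",
    #   }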

    def _compute(self) -> T:
        if len(self.plots) > 0:
            self.create_plots()
        result = {}
        for name, metric in self.metrics.items():
            if name in self.thresholds:
                preds_dict = discretize(values=self._preds, threshold=self.thresholds[name])
                if isinstance(preds_dict, dict):
                    metric_results = {
                        t: metric(self._targets, t_preds) for t, t_preds in preds_dict.items()
                    }
                    # keep the threshold with the best metric value
                    max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1])
                    result[f"{name}_threshold"] = max_t
                    preds = discretize(values=self._preds, threshold=max_t)
                else:
                    preds = preds_dict
            else:
                preds = self._preds
            if name in self.bootstrap:
                # bootstrap the metric to get mean and confidence interval
                result[name] = bootstrap(
                    metric_fn=metric,
                    targets=self._targets,
                    predictions=preds,
                    **self.bootstrap_kwargs,  # type: ignore
                )
            else:
                result[name] = metric(self._targets, preds)
        result = to_py_obj(result)
        if self.show_as_markdown:
            import pandas as pd

            result_flat = flatten_dict(result)
            series = pd.Series(result_flat)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # in fact, this is not a series anymore
                    series = series.unstack(-1)
                else:
                    series.index = series.index.get_level_values(0)
            logger.info(
                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
            )
        return result