RAG_Eval / evaluation /stats /correlation.py
Rom89823974978's picture
Further development
bdb49ae
"""Correlation helpers for RQ1 and RQ2 analyses.
Functions here wrap `scipy.stats` to compute non‑parametric correlations
(Spearman ρ, Kendall τ) with optional bootstrap confidence intervals so
results can be reported with uncertainty estimates.
Typical usage
-------------
>>> from evaluation.stats.correlation import corr_ci
>>> rho, (lo, hi), p = corr_ci(x, y, method="spearman", n_boot=1000)
"""
from __future__ import annotations
from typing import Tuple, Sequence, Literal
import numpy as np
from scipy import stats
# Names of the rank-correlation methods accepted by the functions below.
Method = Literal["spearman", "kendall"]
def _correlate(x: Sequence[float], y: Sequence[float], method: Method):
    """Run the rank-correlation test selected by *method* on *x* and *y*.

    Parameters
    ----------
    x, y
        Paired observations, handed to SciPy unchanged.
    method
        'spearman' or 'kendall'.

    Returns
    -------
    The SciPy result object; callers unpack it as ``(statistic, pvalue)``.

    Raises
    ------
    ValueError
        If *method* is not one of the supported names.
    """
    # ``nan_policy="omit"`` asks SciPy to drop NaN observations instead of
    # propagating them into the statistic.
    if method == "spearman":
        return stats.spearmanr(x, y, nan_policy="omit")
    if method == "kendall":
        return stats.kendalltau(x, y, nan_policy="omit")
    raise ValueError(
        f"unknown correlation method {method!r}; "
        "expected 'spearman' or 'kendall'"
    )
def corr_ci(
    x: Sequence[float],
    y: Sequence[float],
    *,
    method: Method = "spearman",
    n_boot: int = 1000,
    ci: float = 0.95,
    random_state: int | None = None,
) -> Tuple[float, Tuple[float, float], float]:
    """Correlation coefficient, bootstrap CI, and p-value.

    The coefficient and p-value come from the test on the full sample. The
    interval is a percentile bootstrap: (x, y) pairs are resampled with
    replacement *n_boot* times and the coefficient is recomputed on each
    resample.

    Parameters
    ----------
    x, y
        Numeric sequences of equal length.
    method
        'spearman' or 'kendall'.
    n_boot
        Number of bootstrap resamples for the CI. 0 → no CI.
    ci
        Confidence level between 0 and 1 (e.g. 0.95 for 95 %). Only checked
        when a CI is requested (*n_boot* > 0).
    random_state
        Seed for reproducibility.

    Returns
    -------
    r : float
        Correlation coefficient.
    (lo, hi) : Tuple[float, float]
        Lower/upper CI bounds. ``(nan, nan)`` if *n_boot* == 0 or if no
        resample produced a finite coefficient.
    p : float
        Two-sided p-value from the correlation test.

    Raises
    ------
    ValueError
        If *x* and *y* differ in shape, *n_boot* is negative, *ci* is
        outside [0, 1] while a CI is requested, or *method* is unsupported.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError("x and y must have the same length")
    if n_boot < 0:
        raise ValueError(f"n_boot must be >= 0, got {n_boot!r}")
    # Written as ``not (0 <= ci <= 1)`` so that a NaN level is rejected too.
    if n_boot > 0 and not (0.0 <= ci <= 1.0):
        raise ValueError(f"ci must be between 0 and 1, got {ci!r}")

    r, p = _correlate(x, y, method)
    no_interval = (float("nan"), float("nan"))
    if n_boot == 0:
        return float(r), no_interval, float(p)

    rng = np.random.default_rng(random_state)
    n_obs = len(x)
    bs = []
    for _ in range(n_boot):
        # Resample whole (x, y) pairs so the pairing is preserved.
        idx = rng.integers(0, n_obs, n_obs)
        r_bs, _ = _correlate(x[idx], y[idx], method)
        bs.append(r_bs)

    # A resample in which x or y happens to be constant has an undefined
    # (NaN) coefficient; a single such replicate would turn both percentile
    # bounds into NaN, so only finite replicates enter the interval.
    replicates = np.asarray(bs, dtype=float)
    replicates = replicates[np.isfinite(replicates)]
    if replicates.size == 0:
        return float(r), no_interval, float(p)

    lo, hi = np.percentile(
        replicates, [(1 - ci) / 2 * 100, (1 + ci) / 2 * 100]
    )
    return float(r), (float(lo), float(hi)), float(p)