File size: 17,207 Bytes
9f5e57c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
"""
Confidence calibration framework for RAG systems based on research best practices.

Implements Expected Calibration Error (ECE), Adaptive Calibration Error (ACE),
Maximum Calibration Error (MCE), Brier score, negative log-likelihood (NLL),
temperature scaling, and reliability diagrams for measuring and improving
confidence calibration.

References:
- Guo et al. "On Calibration of Modern Neural Networks" (2017)
- Kumar et al. "Verified Uncertainty Calibration" (2019)
- RAG-specific calibration research (2024)
"""

import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss
import json
from pathlib import Path


@dataclass
class CalibrationMetrics:
    """Container for calibration evaluation metrics.

    Built by ``CalibrationEvaluator.evaluate_calibration``. Every scalar below
    is an error/loss, so lower values mean better calibration.
    """
    ece: float  # Expected Calibration Error: bin-mass-weighted mean |confidence - accuracy| over equal-width bins
    ace: float  # Adaptive Calibration Error: like ECE, but over bins holding (roughly) equal numbers of samples
    mce: float  # Maximum Calibration Error: largest |confidence - accuracy| over non-empty equal-width bins
    brier_score: float  # Brier Score: mean squared difference between confidence and correctness
    negative_log_likelihood: float  # Negative Log Likelihood: mean binary cross-entropy of confidences vs. correctness
    # Per-bin plotting data with keys "bin_centers", "bin_confidences",
    # "bin_accuracies" and "bin_counts" (one entry per bin), as built by
    # CalibrationEvaluator._compute_reliability_diagram_data.
    reliability_diagram_data: Dict[str, List[float]]


@dataclass
class CalibrationDataPoint:
    """Single data point for calibration evaluation.

    Pairs the confidence a RAG pipeline reported for one answer with a
    correctness label, so calibration metrics can compare the two.
    """
    predicted_confidence: float  # model-reported confidence; treated as a probability, assumed to lie in [0, 1]
    # Correctness label: 1.0 = correct, 0.0 = incorrect. NOTE: the heuristic
    # labeller _assess_answer_correctness in this module also emits 0.5 for
    # borderline answers, so this is not strictly binary.
    actual_correctness: float
    query: str  # the question that was asked
    answer: str  # the generated answer text
    # Mean relevance of the cited context as filled by
    # create_evaluation_dataset_from_test_results (0.0 when there are no citations).
    context_relevance: float
    metadata: Dict[str, Any]  # free-form extras, e.g. model_used / retrieval_method / num_citations


class ConfidenceCalibrator:
    """
    Implements temperature scaling for RAG confidence scores.

    Temperature scaling maps a confidence ``p`` to ``sigmoid(logit(p) / T)``.
    ``T > 1`` pulls confidences toward 0.5, which corrects overconfidence.
    ``T < 1`` pushes them toward 0 or 1. ``T == 1`` leaves them unchanged.
    """

    def __init__(self):
        # Fitted temperature; stays None until fit_temperature_scaling succeeds.
        self.temperature: Optional[float] = None
        self.is_fitted = False

    def fit_temperature_scaling(
        self,
        confidences: List[float],
        correctness: List[float],
        objective: str = "ece",
        bounds: Tuple[float, float] = (0.1, 5.0),
    ) -> float:
        """
        Fit the temperature parameter on validation data.

        Args:
            confidences: Predicted confidence scores in [0, 1].
            correctness: Ground-truth correctness labels. They are normally
                0.0/1.0 and must have the same length as ``confidences``.
            objective: Loss minimised over T.
                ``"ece"`` (the default, original behaviour) minimises the
                Expected Calibration Error. ECE is not smooth in T because
                samples jump between bins, so the bounded search may stop at
                a local minimum.
                ``"nll"`` minimises binary negative log-likelihood, the smooth
                objective used by Guo et al. (2017).
            bounds: ``(lower, upper)`` search interval for T, with
                ``0 < lower < upper``.

        Returns:
            Optimal temperature. It is also stored in ``self.temperature``.

        Raises:
            ValueError: If the inputs are empty or of different lengths, the
                objective is unknown, or the bounds are invalid.
        """
        from scipy.optimize import minimize_scalar

        # Validate before touching any fitted state.
        if len(confidences) != len(correctness):
            raise ValueError(
                "confidences and correctness must have the same length "
                f"({len(confidences)} != {len(correctness)})"
            )
        if len(confidences) == 0:
            raise ValueError("Cannot fit temperature scaling on empty data")
        lower, upper = bounds
        if not 0 < lower < upper:
            raise ValueError(f"bounds must satisfy 0 < lower < upper, got {bounds}")

        if objective == "ece":
            # CalibrationEvaluator is defined later in this module; it is
            # looked up at call time, so the forward reference is fine.
            evaluator = CalibrationEvaluator()

            def temperature_objective(temp: float) -> float:
                """ECE of the confidences after scaling by ``temp``."""
                calibrated_confidences = self._apply_temperature_scaling(confidences, temp)
                return evaluator._compute_ece(calibrated_confidences, correctness)

        elif objective == "nll":
            labels = np.asarray(correctness, dtype=float)

            def temperature_objective(temp: float) -> float:
                """Mean binary cross-entropy after scaling by ``temp``."""
                probs = np.clip(
                    self._apply_temperature_scaling(confidences, temp), 1e-8, 1 - 1e-8
                )
                return float(
                    -np.mean(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))
                )

        else:
            raise ValueError(f"Unknown objective {objective!r}; expected 'ece' or 'nll'")

        result = minimize_scalar(temperature_objective, bounds=(lower, upper), method='bounded')
        # Store a plain Python float rather than a numpy scalar.
        self.temperature = float(result.x)
        self.is_fitted = True

        return self.temperature

    def _apply_temperature_scaling(
        self,
        confidences: List[float],
        temperature: float
    ) -> List[float]:
        """Return ``sigmoid(logit(p) / temperature)`` for every confidence ``p``."""
        confidences = np.array(confidences)
        # Keep p strictly inside (0, 1) so the logit stays finite.
        confidences = np.clip(confidences, 1e-8, 1 - 1e-8)

        logits = np.log(confidences / (1 - confidences))
        scaled_logits = logits / temperature
        scaled_confidences = 1 / (1 + np.exp(-scaled_logits))

        return scaled_confidences.tolist()

    def calibrate_confidence(self, confidence: float) -> float:
        """
        Apply the fitted temperature scaling to a single confidence score.

        Args:
            confidence: Raw confidence score.

        Returns:
            Calibrated confidence score.

        Raises:
            ValueError: If called before ``fit_temperature_scaling``.
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted before use")

        return self._apply_temperature_scaling([confidence], self.temperature)[0]


class CalibrationEvaluator:
    """
    Evaluates confidence calibration using standard metrics.

    Implements ECE, ACE, MCE, Brier score, negative log-likelihood, and
    reliability diagrams.

    ECE, MCE and the reliability diagram split [0, 1] into ``n_bins``
    equal-width bins of the form ``(lower, upper]``. The first bin is also
    closed on the left, so a confidence of exactly 0.0 is counted instead of
    falling outside every bin. Values outside [0, 1] land in no bin.
    """

    def __init__(self, n_bins: int = 10):
        """
        Args:
            n_bins: Number of bins used by every binned metric (must be >= 1).

        Raises:
            ValueError: If ``n_bins`` is smaller than 1.
        """
        if n_bins < 1:
            raise ValueError(f"n_bins must be >= 1, got {n_bins}")
        self.n_bins = n_bins

    def evaluate_calibration(
        self,
        data_points: "List[CalibrationDataPoint]"
    ) -> "CalibrationMetrics":
        """
        Compute comprehensive calibration metrics.

        Correctness labels may be soft (e.g. 0.5 for partial credit). The
        Brier score is therefore computed directly as the mean squared error
        between confidence and label. For binary labels this equals
        scikit-learn's ``brier_score_loss``; unlike it, non-binary targets are
        not rejected.

        Args:
            data_points: Non-empty list of calibration data points.

        Returns:
            CalibrationMetrics with all computed metrics.

        Raises:
            ValueError: If ``data_points`` is empty, or if any confidence lies
                outside [0, 1] or is NaN.
        """
        if not data_points:
            raise ValueError("Cannot evaluate calibration on an empty dataset")

        confidences = [dp.predicted_confidence for dp in data_points]
        correctness = [dp.actual_correctness for dp in data_points]

        # Confidences must be probabilities; NaN fails this check as well.
        conf_arr = np.asarray(confidences, dtype=float)
        if not np.all((conf_arr >= 0.0) & (conf_arr <= 1.0)):
            raise ValueError("predicted_confidence values must lie in [0, 1]")

        return CalibrationMetrics(
            ece=self._compute_ece(confidences, correctness),
            ace=self._compute_ace(confidences, correctness),
            mce=self._compute_mce(confidences, correctness),
            brier_score=self._compute_brier(confidences, correctness),
            negative_log_likelihood=self._compute_nll(confidences, correctness),
            reliability_diagram_data=self._compute_reliability_diagram_data(
                confidences, correctness
            ),
        )

    def _bin_masks(self, confidences: np.ndarray) -> List[np.ndarray]:
        """Return one boolean membership mask per equal-width bin over [0, 1]."""
        boundaries = np.linspace(0, 1, self.n_bins + 1)
        masks = []
        for i, (lower, upper) in enumerate(zip(boundaries[:-1], boundaries[1:])):
            mask = (confidences > lower) & (confidences <= upper)
            if i == 0:
                # Close the first bin on the left so confidence 0.0 is not dropped.
                mask |= confidences == lower
            masks.append(mask)
        return masks

    def _compute_ece(self, confidences: List[float], correctness: List[float]) -> float:
        """
        Compute Expected Calibration Error (ECE).

        ECE is the mean |avg confidence - accuracy| over bins, weighted by the
        fraction of samples in each bin. It returns 0.0 for empty input.
        """
        confidences = np.asarray(confidences, dtype=float)
        correctness = np.asarray(correctness, dtype=float)

        ece = 0.0
        for in_bin in self._bin_masks(confidences):
            if in_bin.any():
                gap = abs(confidences[in_bin].mean() - correctness[in_bin].mean())
                ece += gap * in_bin.mean()  # in_bin.mean() is the bin's sample share

        return float(ece)

    def _compute_ace(self, confidences: List[float], correctness: List[float]) -> float:
        """
        Compute Adaptive Calibration Error (ACE).

        ACE addresses binning bias by using equal-mass bins. Samples are
        sorted by confidence and split into ``n_bins`` contiguous groups whose
        sizes differ by at most one. Returns 0.0 for empty input.
        """
        confidences = np.asarray(confidences, dtype=float)
        correctness = np.asarray(correctness, dtype=float)
        n_samples = confidences.size
        if n_samples == 0:
            return 0.0

        # A stable sort keeps tied confidences in input order (deterministic bins).
        order = np.argsort(confidences, kind="stable")
        conf_bins = np.array_split(confidences[order], self.n_bins)
        corr_bins = np.array_split(correctness[order], self.n_bins)

        ace = 0.0
        for bin_conf, bin_corr in zip(conf_bins, corr_bins):
            if bin_conf.size:  # empty when n_samples < n_bins
                gap = abs(bin_conf.mean() - bin_corr.mean())
                ace += gap * bin_conf.size / n_samples

        return float(ace)

    def _compute_mce(self, confidences: List[float], correctness: List[float]) -> float:
        """
        Compute Maximum Calibration Error (MCE).

        MCE is the largest |avg confidence - accuracy| over non-empty bins.
        Returns 0.0 for empty input.
        """
        confidences = np.asarray(confidences, dtype=float)
        correctness = np.asarray(correctness, dtype=float)

        max_error = 0.0
        for in_bin in self._bin_masks(confidences):
            if in_bin.any():
                gap = abs(confidences[in_bin].mean() - correctness[in_bin].mean())
                max_error = max(max_error, gap)

        return float(max_error)

    def _compute_brier(self, confidences: List[float], correctness: List[float]) -> float:
        """Compute the Brier score: mean squared gap between confidence and label."""
        confidences = np.asarray(confidences, dtype=float)
        correctness = np.asarray(correctness, dtype=float)
        return float(np.mean((confidences - correctness) ** 2))

    def _compute_nll(self, confidences: List[float], correctness: List[float]) -> float:
        """Compute mean binary Negative Log Likelihood (cross-entropy)."""
        confidences = np.asarray(confidences, dtype=float)
        correctness = np.asarray(correctness, dtype=float)

        # Avoid log(0).
        confidences = np.clip(confidences, 1e-8, 1 - 1e-8)

        # NLL = -mean[y*log(p) + (1-y)*log(1-p)]
        nll = -(correctness * np.log(confidences) +
                (1 - correctness) * np.log(1 - confidences)).mean()

        return float(nll)

    def _compute_reliability_diagram_data(
        self,
        confidences: List[float],
        correctness: List[float]
    ) -> Dict[str, List[float]]:
        """
        Compute per-bin data for reliability diagram visualization.

        Returns a dict of equal-length lists: ``bin_centers``,
        ``bin_confidences`` (mean confidence per bin), ``bin_accuracies``
        (mean correctness per bin) and ``bin_counts`` (plain ``int``s, so the
        dict is JSON-serializable). Empty bins report their center as the
        confidence and 0.0 as the accuracy; check ``bin_counts`` before
        interpreting them.
        """
        confidences = np.asarray(confidences, dtype=float)
        correctness = np.asarray(correctness, dtype=float)

        bin_boundaries = np.linspace(0, 1, self.n_bins + 1)
        bin_centers = (bin_boundaries[:-1] + bin_boundaries[1:]) / 2

        bin_confidences: List[float] = []
        bin_accuracies: List[float] = []
        bin_counts: List[int] = []

        for center, in_bin in zip(bin_centers, self._bin_masks(confidences)):
            count = int(in_bin.sum())
            if count:
                bin_confidences.append(float(confidences[in_bin].mean()))
                bin_accuracies.append(float(correctness[in_bin].mean()))
            else:
                bin_confidences.append(float(center))
                bin_accuracies.append(0.0)
            bin_counts.append(count)

        return {
            "bin_centers": bin_centers.tolist(),
            "bin_confidences": bin_confidences,
            "bin_accuracies": bin_accuracies,
            "bin_counts": bin_counts
        }

    def plot_reliability_diagram(
        self,
        metrics: "CalibrationMetrics",
        save_path: Optional[Path] = None
    ) -> None:
        """
        Create a reliability diagram and either save it or show it.

        Bars show per-bin accuracy. Red segments show the gap between the
        accuracy and the mean confidence of each non-empty bin.

        Args:
            metrics: CalibrationMetrics containing reliability data.
            save_path: Optional path to save the plot. When it is not given,
                the plot is shown interactively.
        """
        data = metrics.reliability_diagram_data
        centers = data["bin_centers"]
        # Older data may lack counts; treat every bin as non-empty then.
        counts = data.get("bin_counts", [1] * len(centers))
        # 80% of the bin spacing (0.08 for the default 10 bins).
        bar_width = 0.8 / max(len(centers), 1)

        fig, ax = plt.subplots(figsize=(8, 6))

        # Plot reliability line (perfect calibration)
        ax.plot([0, 1], [0, 1], 'k--', alpha=0.7, label='Perfect calibration')

        # Plot actual calibration
        ax.bar(
            centers,
            data["bin_accuracies"],
            width=bar_width,
            alpha=0.7,
            edgecolor='black',
            label='Model calibration'
        )

        # Plot the gap between confidence and accuracy. Empty bins hold
        # placeholder values, and a gap drawn for them would be fictitious.
        for center, conf, acc, count in zip(
            centers,
            data["bin_confidences"],
            data["bin_accuracies"],
            counts,
        ):
            if count > 0 and conf != acc:
                ax.plot([center, center], [acc, conf], 'r-', alpha=0.8, linewidth=2)

        ax.set_xlabel('Confidence')
        ax.set_ylabel('Accuracy')
        ax.set_title(f'Reliability Diagram (ECE: {metrics.ece:.3f})')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.close(fig)
        else:
            plt.show()


def create_evaluation_dataset_from_test_results(
    test_results: List[Dict[str, Any]]
) -> List[CalibrationDataPoint]:
    """
    Build a calibration dataset from raw RAG test results.

    Each result dict is read for ``confidence`` (default 0.0), ``query``,
    ``answer``, ``citations``, ``model_used`` and ``retrieval_method``.
    Missing keys fall back to empty or zero values. Correctness labels come
    from the heuristic ``_assess_answer_correctness``. Context relevance comes
    from ``_compute_context_relevance``.

    Args:
        test_results: List of test result dictionaries.

    Returns:
        One CalibrationDataPoint per input result, in input order.
    """
    def _convert(record: Dict[str, Any]) -> CalibrationDataPoint:
        # Labelling is heuristic until domain-specific correctness checks exist.
        label = _assess_answer_correctness(record)
        return CalibrationDataPoint(
            predicted_confidence=record.get('confidence', 0.0),
            actual_correctness=label,
            query=record.get('query', ''),
            answer=record.get('answer', ''),
            context_relevance=_compute_context_relevance(record),
            metadata={
                'model_used': record.get('model_used', ''),
                'retrieval_method': record.get('retrieval_method', ''),
                'num_citations': len(record.get('citations', [])),
            },
        )

    return [_convert(record) for record in test_results]


def _assess_answer_correctness(result: Dict[str, Any]) -> float:
    """
    Assess answer correctness for calibration evaluation.
    
    This is a simplified heuristic - in practice, this should be:
    1. Human evaluation
    2. Automated fact-checking against ground truth
    3. Domain-specific quality metrics
    """
    answer = result.get('answer', '').lower()
    citations = result.get('citations', [])
    
    # Simple heuristic: consider correct if has citations and no uncertainty
    uncertainty_phrases = [
        'cannot answer', 'not contained', 'no relevant', 
        'insufficient information', 'unclear', 'not specified'
    ]
    
    has_uncertainty = any(phrase in answer for phrase in uncertainty_phrases)
    has_citations = len(citations) > 0
    
    if has_uncertainty:
        return 0.0  # Explicit uncertainty = incorrect/no answer
    elif has_citations and len(answer.split()) > 10:
        return 1.0  # Has citations and substantial answer = likely correct
    else:
        return 0.5  # Partial credit for borderline cases


def _compute_context_relevance(result: Dict[str, Any]) -> float:
    """Compute average relevance of retrieved context."""
    citations = result.get('citations', [])
    if not citations:
        return 0.0
    
    relevances = [citation.get('relevance', 0.0) for citation in citations]
    return sum(relevances) / len(relevances)


if __name__ == "__main__":
    # Demo / smoke test: simulate an overconfident model, report its
    # calibration, fit a temperature on the same samples, then report again.
    print("Testing confidence calibration framework...")

    # Fixed seed so the printed numbers are reproducible.
    np.random.seed(42)
    n_samples = 100

    # Labels are ~60% correct, while Beta(8, 3) confidences average 8/11
    # (about 0.73), so the simulated model is overconfident.
    true_correctness = np.random.binomial(1, 0.6, n_samples)
    predicted_confidence = np.random.beta(8, 3, n_samples)

    def _as_data_points(scores):
        """Pair each score with the matching ground-truth label."""
        points = []
        for idx, (score, label) in enumerate(zip(scores, true_correctness)):
            points.append(
                CalibrationDataPoint(
                    predicted_confidence=score,
                    actual_correctness=float(label),
                    query=f"query_{idx}",
                    answer=f"answer_{idx}",
                    context_relevance=0.7,
                    metadata={},
                )
            )
        return points

    def _print_summary(heading, summary):
        """Print *heading* followed by the headline metrics of *summary*."""
        print(heading)
        for name, value in (
            ("ECE", summary.ece),
            ("ACE", summary.ace),
            ("MCE", summary.mce),
            ("Brier Score", summary.brier_score),
        ):
            print(f"  {name}: {value:.3f}")

    evaluator = CalibrationEvaluator()
    raw_summary = evaluator.evaluate_calibration(_as_data_points(predicted_confidence))
    _print_summary("Before calibration:", raw_summary)

    # Fit temperature scaling on the same samples (in-sample, demo only).
    calibrator = ConfidenceCalibrator()
    optimal_temp = calibrator.fit_temperature_scaling(
        predicted_confidence.tolist(),
        true_correctness.tolist(),
    )
    print(f"  Optimal temperature: {optimal_temp:.3f}")

    calibrated_confidences = [
        calibrator.calibrate_confidence(score) for score in predicted_confidence
    ]
    scaled_summary = evaluator.evaluate_calibration(_as_data_points(calibrated_confidences))
    _print_summary("\nAfter temperature scaling:", scaled_summary)

    print("\n✅ Calibration framework working correctly!")