# Mezura / utils / rag_score_calculator.py
# Uploaded by nmmursit ("Upload 5 files"), commit 8c404fc (verified).
import json
import os
from typing import Dict, List, Tuple
class RAGScoreCalculator:
    """
    Dynamic RAG Score calculator that calculates scores at runtime
    without modifying the original JSON files.

    On construction it reads every ``detail_*`` file in ``retrieval_dir``,
    derives min/max normalization ranges from them, and exposes
    :meth:`calculate_rag_score` to score a single result dict.
    """

    # Fixed normalization range for ``total_missed_references``. Unlike the
    # other metrics, this range is hard-coded instead of derived from the files.
    MISSED_REFS_MIN = 0
    MISSED_REFS_MAX = 7114

    def __init__(self, retrieval_dir: str = "result/retrieval"):
        self.retrieval_dir = retrieval_dir
        # Per-metric {'min', 'max'} ranges; None until at least one file loads.
        self.stats = None
        # Parsed contents of every successfully loaded detail file.
        self.all_data = None
        self._load_and_analyze()

    @staticmethod
    def _metric(record: Dict, key: str):
        """Return ``record[key]``, treating a missing or JSON-null value as 0."""
        value = record.get(key)
        return 0 if value is None else value

    def _load_and_analyze(self):
        """Load all retrieval detail files and calculate normalization statistics.

        ``all_data`` and ``stats`` are reset first. A reload that finds no
        usable data therefore leaves ``stats`` as None instead of keeping the
        ranges from an earlier load.
        """
        self.all_data = []
        self.stats = None

        if not os.path.isdir(self.retrieval_dir):
            print(f"Warning: Retrieval directory not found: {self.retrieval_dir}")
            return

        # Sorted so the load order (and any "first N models" view) is
        # deterministic; os.listdir returns entries in arbitrary order.
        detail_files = sorted(
            name for name in os.listdir(self.retrieval_dir) if name.startswith('detail_')
        )
        if not detail_files:
            print("Warning: No detail files found in retrieval directory")
            return

        for filename in detail_files:
            filepath = os.path.join(self.retrieval_dir, filename)
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Error loading {filename}: {e}")
                continue
            if not isinstance(data, dict):
                # Metrics are read with dict lookups, so other JSON shapes
                # (lists, scalars) cannot be scored.
                print(f"Error loading {filename}: expected a JSON object, "
                      f"got {type(data).__name__}")
                continue
            self.all_data.append(data)

        if not self.all_data:
            print("Warning: No valid data loaded from detail files")
            return

        # Calculate normalization statistics
        self._calculate_stats()

    def _calculate_stats(self):
        """Calculate min/max statistics for normalization from ``all_data``."""
        if not self.all_data:
            return

        def value_range(key):
            values = [self._metric(d, key) for d in self.all_data]
            return {'min': min(values), 'max': max(values)}

        self.stats = {
            'rag_success_rate': value_range('RAG_success_rate'),
            'max_correct_references': value_range('max_correct_references'),
            'total_false_positives': value_range('total_false_positives'),
            'total_missed_references': {
                'min': self.MISSED_REFS_MIN,
                'max': self.MISSED_REFS_MAX,
            },
        }

    def normalize_value(self, value, min_val, max_val, higher_is_better=True):
        """Normalize a value to the 0-1 range.

        Args:
            value: Raw metric value.
            min_val: Lower bound of the normalization range.
            max_val: Upper bound of the normalization range.
            higher_is_better: When False, the result is flipped (1 - x) so
                that smaller raw values score higher.

        Returns:
            A float in [0, 1]. Values outside [min_val, max_val] are clamped.
            Returns 1.0 when min_val == max_val, because there is no spread
            to normalize against.
        """
        if max_val == min_val:
            return 1.0  # If all values are the same, return 1
        normalized = (value - min_val) / (max_val - min_val)
        # Clamp so out-of-range inputs (e.g. a value above the fixed
        # missed-references maximum) cannot push a term outside 0-1.
        normalized = min(max(normalized, 0.0), 1.0)
        if not higher_is_better:
            normalized = 1 - normalized  # Flip for "lower is better" metrics
        return normalized

    def calculate_rag_score(self, data: Dict) -> float:
        """Calculate the RAG score for a single model's data.

        Args:
            data: A result dict with the keys ``RAG_success_rate``,
                ``max_correct_references``, ``total_false_positives`` and
                ``total_missed_references``. Missing or null keys count as 0.

        Returns:
            The weighted score rounded to 4 decimals. Returns 0.0, and prints
            a warning, when no normalization statistics are available.
        """
        if not self.stats:
            print("Warning: No statistics available for normalization")
            return 0.0

        # NOTE(review): RAG_success_rate is used as-is (not normalized). This
        # assumes it is already on a 0-1 scale; confirm against the detail files.
        rag_success_rate = self._metric(data, 'RAG_success_rate')
        max_correct_refs = self._metric(data, 'max_correct_references')
        false_positives = self._metric(data, 'total_false_positives')
        missed_refs = self._metric(data, 'total_missed_references')

        norm_max_correct = self.normalize_value(
            max_correct_refs,
            self.stats['max_correct_references']['min'],
            self.stats['max_correct_references']['max'],
            higher_is_better=True,
        )
        norm_false_positives = self.normalize_value(
            false_positives,
            self.stats['total_false_positives']['min'],
            self.stats['total_false_positives']['max'],
            higher_is_better=False,  # Lower is better
        )
        norm_missed_refs = self.normalize_value(
            missed_refs,
            self.stats['total_missed_references']['min'],
            self.stats['total_missed_references']['max'],
            higher_is_better=False,  # Lower is better
        )

        # Weights: rag_success_rate=0.9, false_positives=0.9, max_correct=0.1,
        # missed_refs=0.1. They sum to 2.0, so dividing by 2.0 keeps the score
        # in 0-1 when every term is in 0-1.
        rag_score = (
            0.9 * rag_success_rate
            + 0.9 * norm_false_positives
            + 0.1 * norm_max_correct
            + 0.1 * norm_missed_refs
        ) / 2.0
        return round(rag_score, 4)

    def get_normalization_info(self) -> Dict:
        """Get current normalization statistics for debugging."""
        return {
            'stats': self.stats,
            'total_files': len(self.all_data) if self.all_data else 0,
            'retrieval_dir': self.retrieval_dir,
        }

    def refresh_stats(self) -> bool:
        """Reload data and recompute statistics; call this when new data is added.

        Returns:
            True if statistics are available after the reload, else False.
        """
        print("Refreshing RAG Score normalization statistics...")
        self._load_and_analyze()
        return self.stats is not None
def main():
    """Main function for testing RAG score calculations.

    Builds a calculator on the default retrieval directory and prints the
    file count, the directory and each metric's min/max range. It then prints
    scores for up to the first five loaded records.
    """
    calculator = RAGScoreCalculator()
    print("RAG Score Calculator (Runtime Only)")
    print("===================================")

    # Show normalization info
    info = calculator.get_normalization_info()
    print(f"Total files: {info['total_files']}")
    print(f"Retrieval directory: {info['retrieval_dir']}")

    if not info['stats']:
        print("\n❌ No statistics available for normalization")
        return

    print("\nNormalization ranges:")
    for metric, value_range in info['stats'].items():
        print(f" {metric}: {value_range['min']} - {value_range['max']}")

    print("\nSample RAG Score calculations:")
    for record in calculator.all_data[:5]:  # Show first 5
        rag_score = calculator.calculate_rag_score(record)
        model_name = record.get('model_name', 'Unknown')
        print(f" {model_name}: {rag_score}")
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()