|
import json |
|
import os |
|
from typing import Dict, List, Tuple |
|
|
|
class RAGScoreCalculator:
    """Dynamic RAG Score calculator that computes scores at runtime
    without modifying the original JSON files.

    Loads every ``detail_*`` JSON file found in the retrieval directory,
    derives min/max normalization statistics across all loaded models,
    and scores an individual model's metrics against those ranges.
    """

    def __init__(self, retrieval_dir: str = "result/retrieval"):
        """Load detail files from *retrieval_dir* and compute statistics.

        Args:
            retrieval_dir: Directory containing ``detail_*`` JSON files.
        """
        self.retrieval_dir = retrieval_dir
        # Per-metric {'min': ..., 'max': ...} ranges, or None until loaded.
        self.stats = None
        # List of raw detail dicts, or None until loaded.
        self.all_data = None
        self._load_and_analyze()

    def _load_and_analyze(self) -> None:
        """Load all retrieval detail files and calculate normalization statistics."""
        self.all_data = []
        detail_files = [f for f in os.listdir(self.retrieval_dir) if f.startswith('detail_')]

        if not detail_files:
            print("Warning: No detail files found in retrieval directory")
            return

        for filename in detail_files:
            filepath = os.path.join(self.retrieval_dir, filename)
            try:
                # JSON is specified as UTF-8; don't depend on the locale default.
                with open(filepath, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                self.all_data.append(data)
            except (OSError, json.JSONDecodeError) as e:
                # Best-effort load: skip unreadable/corrupt files, but report
                # which one failed (the original message omitted the filename).
                print(f"Error loading {filename}: {e}")
                continue

        if not self.all_data:
            print("Warning: No valid data loaded from detail files")
            return

        self._calculate_stats()

    def _calculate_stats(self) -> None:
        """Calculate min/max statistics for normalization across all loaded models."""
        if not self.all_data:
            return

        rag_success_rates = [d.get('RAG_success_rate', 0) for d in self.all_data]
        max_correct_refs = [d.get('max_correct_references', 0) for d in self.all_data]
        false_positives = [d.get('total_false_positives', 0) for d in self.all_data]

        self.stats = {
            'rag_success_rate': {
                'min': min(rag_success_rates),
                'max': max(rag_success_rates)
            },
            'max_correct_references': {
                'min': min(max_correct_refs),
                'max': max(max_correct_refs)
            },
            'total_false_positives': {
                'min': min(false_positives),
                'max': max(false_positives)
            },
            'total_missed_references': {
                # NOTE(review): fixed range rather than data-derived, which keeps
                # scores comparable as new result files are added. 7114 is
                # presumably the benchmark's total reference count — TODO confirm.
                'min': 0,
                'max': 7114
            }
        }

    def normalize_value(self, value, min_val, max_val, higher_is_better=True):
        """Normalize *value* into the 0-1 range given [min_val, max_val].

        Args:
            value: Raw metric value.
            min_val: Lower bound of the observed/assumed range.
            max_val: Upper bound of the observed/assumed range.
            higher_is_better: When False the normalized value is inverted
                so that 1.0 always means "best".

        Returns:
            Normalized float in [0, 1]; 1.0 when the range is degenerate
            (max_val == min_val), since every value is then the "best" one.
        """
        if max_val == min_val:
            return 1.0

        normalized = (value - min_val) / (max_val - min_val)

        if not higher_is_better:
            normalized = 1 - normalized

        return normalized

    def calculate_rag_score(self, data: Dict) -> float:
        """Calculate the RAG score for a single model's data.

        Args:
            data: A detail dict with 'RAG_success_rate',
                'max_correct_references', 'total_false_positives' and
                'total_missed_references' keys (missing keys default to 0).

        Returns:
            Weighted composite score rounded to 4 decimals, or 0.0 when no
            normalization statistics are available.
        """
        if not self.stats:
            print("Warning: No statistics available for normalization")
            return 0.0

        # assumes RAG_success_rate is already in [0, 1] (it is used
        # unnormalized below) — TODO confirm against the detail files.
        rag_success_rate = data.get('RAG_success_rate', 0)
        max_correct_refs = data.get('max_correct_references', 0)
        false_positives = data.get('total_false_positives', 0)
        missed_refs = data.get('total_missed_references', 0)

        norm_max_correct = self.normalize_value(
            max_correct_refs,
            self.stats['max_correct_references']['min'],
            self.stats['max_correct_references']['max'],
            higher_is_better=True
        )

        norm_false_positives = self.normalize_value(
            false_positives,
            self.stats['total_false_positives']['min'],
            self.stats['total_false_positives']['max'],
            higher_is_better=False
        )

        norm_missed_refs = self.normalize_value(
            missed_refs,
            self.stats['total_missed_references']['min'],
            self.stats['total_missed_references']['max'],
            higher_is_better=False
        )

        # Weights (0.9, 0.9, 0.1, 0.1) sum to 2.0; dividing by 2.0 keeps the
        # composite score inside [0, 1] when every component is in [0, 1].
        rag_score = (
            0.9 * rag_success_rate +
            0.9 * norm_false_positives +
            0.1 * norm_max_correct +
            0.1 * norm_missed_refs
        ) / 2.0

        return round(rag_score, 4)

    def get_normalization_info(self) -> Dict:
        """Get current normalization statistics for debugging."""
        return {
            'stats': self.stats,
            'total_files': len(self.all_data) if self.all_data else 0,
            'retrieval_dir': self.retrieval_dir
        }

    def refresh_stats(self) -> bool:
        """Refresh statistics by reloading data - call this when new data is added.

        Returns:
            True when statistics were successfully (re)computed.
        """
        print("Refreshing RAG Score normalization statistics...")
        self._load_and_analyze()
        return self.stats is not None
|
|
|
def main():
    """Demonstrate RAG score calculation against the default retrieval dir.

    Prints the normalization ranges and a sample of per-model scores for
    the first five loaded detail files.
    """
    calculator = RAGScoreCalculator()

    print("RAG Score Calculator (Runtime Only)")
    print("===================================")

    info = calculator.get_normalization_info()
    print(f"Total files: {info['total_files']}")
    print(f"Retrieval directory: {info['retrieval_dir']}")

    if info['stats']:
        print("\nNormalization ranges:")
        for metric, data in info['stats'].items():
            print(f" {metric}: {data['min']} - {data['max']}")

        print("\nSample RAG Score calculations:")
        # Show the first five loaded models as a quick sanity check.
        for data in calculator.all_data[:5]:
            rag_score = calculator.calculate_rag_score(data)
            model_name = data.get('model_name', 'Unknown')
            print(f" {model_name}: {rag_score}")
    else:
        print("\n❌ No statistics available for normalization")


if __name__ == "__main__":
    main()