import random
from typing import Any, Dict, Optional, Tuple

from datasets import load_dataset

from .base_benchmark import BaseBenchmark
from .evaluation_utils import extract_answer_mmlu
from .prompt_templates import get_mmlu_prompt


class MMLUBenchmark(BaseBenchmark):
    """MMLU (Massive Multitask Language Understanding) benchmark."""

    SUBJECTS = [
        'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics',
        'clinical_knowledge', 'college_biology', 'college_chemistry',
        'college_computer_science', 'college_mathematics', 'college_medicine',
        'college_physics', 'computer_security', 'conceptual_physics',
        'econometrics', 'electrical_engineering', 'elementary_mathematics',
        'formal_logic', 'global_facts', 'high_school_biology',
        'high_school_chemistry', 'high_school_computer_science',
        'high_school_european_history', 'high_school_geography',
        'high_school_government_and_politics', 'high_school_macroeconomics',
        'high_school_mathematics', 'high_school_microeconomics',
        'high_school_physics', 'high_school_psychology', 'high_school_statistics',
        'high_school_us_history', 'high_school_world_history', 'human_aging',
        'human_sexuality', 'international_law', 'jurisprudence',
        'logical_fallacies', 'machine_learning', 'management', 'marketing',
        'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios',
        'nutrition', 'philosophy', 'prehistory', 'professional_accounting',
        'professional_law', 'professional_medicine', 'professional_psychology',
        'public_relations', 'security_studies', 'sociology', 'us_foreign_policy',
        'virology', 'world_religions'
    ]

    def __init__(self):
        super().__init__(name="MMLU", dataset_name="cais/mmlu")

    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load MMLU few-shot (dev) and evaluation (test) splits."""
        subjects = kwargs.get('subjects', ['all'])
        if 'all' in subjects:
            subjects = self.SUBJECTS
        else:
            subjects = [s for s in subjects if s in self.SUBJECTS]

        self.dataset = []
        self.few_shot_examples = {}  # Store few-shot examples per subject

        for subject in subjects:
            try:
                # Load dev split for few-shot examples; `load_dataset` here is
                # the Hugging Face `datasets` loader, not this method.
                dev_ds = load_dataset(self.dataset_name, subject, split='dev')

                # Standard MMLU evaluation uses 5-shot prompting
                self.few_shot_examples[subject] = [
                    {
                        'question': ex['question'],
                        'choices': ex['choices'],
                        'answer': ex['answer']
                    }
                    for ex in list(dev_ds)[:5]
                ]

                # Load test split for evaluation
                test_ds = load_dataset(self.dataset_name, subject, split='test')
                for sample in test_ds:
                    self.dataset.append({
                        'subject': subject,
                        'question': sample['question'],
                        'choices': sample['choices'],
                        'answer': sample['answer'],  # 0-3 index
                        'raw_sample': sample
                    })
            except Exception as e:
                print(f"Error loading {subject}: {e}")
                continue

        # Shuffle before truncating so a size-limited run covers a mix of
        # subjects (unseeded, so the selection differs between runs)
        random.shuffle(self.dataset)
        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format MMLU question as prompt with few-shot examples."""
        subject = sample['subject']
        few_shot_examples = self.few_shot_examples.get(subject, [])
        return get_mmlu_prompt(
            question=sample['question'],
            choices=sample['choices'],
            subject=subject.replace('_', ' ').title(),
            few_shot_examples=few_shot_examples
        )

    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single MMLU sample."""
        prompt = self.format_prompt(sample)
        try:
            response = await api.generate_with_retry(prompt, **kwargs)

            # Extract answer from response using standard extraction
            predicted_letter = extract_answer_mmlu(response)
            if predicted_letter:
                predicted_index = ord(predicted_letter) - ord('A')
            else:
                # If no clear answer, mark as incorrect
                predicted_index = -1

            correct_index = sample['answer']
            is_correct = predicted_index == correct_index

            result = {
                'subject': sample['subject'],
                'question': sample['question'],
                'choices': sample['choices'],
                'correct_answer': correct_index,
                'predicted_answer': predicted_index,
                'model_response': response,
                'is_correct': is_correct
            }
            return is_correct, result
        except Exception as e:
            result = {
                'subject': sample['subject'],
                'question': sample['question'],
                'error': str(e),
                'is_correct': False
            }
            return False, result
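

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only; not part of the original module).
# It assumes `api` is any client exposing the async `generate_with_retry`
# method that `evaluate_sample` calls; the stub below is a hypothetical
# stand-in for a real model endpoint so the load/format/extract path can be
# exercised locally. Because this file uses relative imports, run it as a
# module (e.g. `python -m <package>.<this_module>`).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    class _StubAPI:
        """Hypothetical client that always answers 'A' (no real model call)."""

        async def generate_with_retry(self, prompt: str, **kwargs) -> str:
            return "Answer: A"

    async def _demo():
        benchmark = MMLUBenchmark()
        # A small single-subject slice keeps the download and loop short.
        await benchmark.load_dataset(sample_size=3, subjects=['anatomy'])
        for sample in benchmark.dataset:
            is_correct, result = await benchmark.evaluate_sample(_StubAPI(), sample)
            print(result['subject'], 'correct' if is_correct else 'incorrect')

    asyncio.run(_demo())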